From f68886d5dfa11f5090a37f35c4049d25dfb23870 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 9 Oct 2023 17:29:50 +0100 Subject: [PATCH] Fix bugged actions --- example_config.yaml | 202 +++++++++--------- src/primaite/game/agent/actions.py | 118 +++++++--- src/primaite/game/agent/observations.py | 58 ++--- .../network/hardware/nodes/router.py | 10 +- .../network/hardware/nodes/switch.py | 11 +- 5 files changed, 230 insertions(+), 169 deletions(-) diff --git a/example_config.yaml b/example_config.yaml index 9f679223..f7faf589 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -63,8 +63,8 @@ game_config: services: - service_ref: data_manipulation_bot observations: - - operating_status - - health_status + operating_status + health_status folders: {} action_space: @@ -197,221 +197,221 @@ game_config: 1: action: NODE_SERVICE_SCAN options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 # stop webapp service 2: action: NODE_SERVICE_STOP options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 # start webapp service 3: action: "NODE_SERVICE_START" options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 4: action: "NODE_SERVICE_PAUSE" options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 5: action: "NODE_SERVICE_RESUME" options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 6: action: "NODE_SERVICE_RESTART" options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 7: action: "NODE_SERVICE_DISABLE" options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 8: action: "NODE_SERVICE_ENABLE" options: - - node_id: 2 - - service_id: 1 + node_id: 2 + service_id: 1 9: action: "NODE_FILE_SCAN" options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + node_id: 3 + folder_id: 1 + file_id: 1 10: action: "NODE_FILE_CHECKHASH" options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + node_id: 3 + folder_id: 1 + file_id: 1 11: action: "NODE_FILE_DELETE" options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + node_id: 3 + folder_id: 1 + file_id: 1 12: action: "NODE_FILE_REPAIR" options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + node_id: 3 + folder_id: 1 + file_id: 1 13: action: "NODE_FILE_RESTORE" options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + node_id: 3 + folder_id: 1 + file_id: 1 14: action: "NODE_FOLDER_SCAN" options: - - node_id: 3 - - folder_id: 1 + node_id: 3 + folder_id: 1 15: action: "NODE_FOLDER_CHECKHASH" options: - - node_id: 3 - - folder_id: 1 + node_id: 3 + folder_id: 1 16: action: "NODE_FOLDER_REPAIR" options: - - node_id: 3 - - folder_id: 1 + node_id: 3 + folder_id: 1 17: action: "NODE_FOLDER_RESTORE" options: - - node_id: 3 - - folder_id: 1 + node_id: 3 + folder_id: 1 18: action: "NODE_OS_SCAN" options: - - node_id: 3 + node_id: 3 19: action: "NODE_SHUTDOWN" options: - - node_id: 6 + node_id: 6 20: action: "NODE_STARTUP" options: - - node_id: 6 + node_id: 6 21: action: "NODE_RESET" options: - - node_id: 6 + node_id: 6 22: action: "NETWORK_ACL_ADDRULE" options: - - position: 6 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 1 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 23: action: "NETWORK_ACL_ADDRULE" options: - - position: 5 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 1 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 24: action: "NETWORK_ACL_ADDRULE" options: - - position: 4 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 3 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 25: action: "NETWORK_ACL_ADDRULE" options: - - position: 3 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 3 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - - position: 2 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 4 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - - position: 1 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 4 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 28: action: "NETWORK_ACL_REMOVERULE" options: - - position: 0 + position: 0 29: action: "NETWORK_ACL_REMOVERULE" options: - - position: 1 + position: 1 30: action: "NETWORK_ACL_REMOVERULE" options: - - position: 2 + position: 2 31: action: "NETWORK_ACL_REMOVERULE" options: - - position: 3 + position: 3 32: action: "NETWORK_ACL_REMOVERULE" options: - - position: 4 + position: 4 33: action: "NETWORK_ACL_REMOVERULE" options: - - position: 5 + position: 5 34: action: "NETWORK_ACL_REMOVERULE" options: - - position: 6 + position: 6 35: action: "NETWORK_ACL_REMOVERULE" options: - - position: 7 + position: 7 36: action: "NETWORK_ACL_REMOVERULE" options: - - position: 8 + position: 8 37: action: "NETWORK_ACL_REMOVERULE" options: - - position: 9 + position: 9 38: action: "NETWORK_NIC_DISABLE" options: - - node_id: 6 - - nic_index: 1 + node_id: 6 + nic_id: 1 39: action: "NETWORK_NIC_ENABLE" options: - - node_id: 6 - - nic_index: 1 + node_id: 6 + nic_id: 1 options: nodes: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 3f674fbb..1e6893ff 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gym import spaces from primaite.simulator.sim_container import Simulation +from primaite import getLogger +_LOGGER = getLogger(__name__) if TYPE_CHECKING: from primaite.game.session import PrimaiteSession @@ -253,7 +255,7 @@ class NodeShutdownAction(NodeAbstractAction): class NodeStartupAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "start" + self.verb = "startup" class NodeResetAction(NodeAbstractAction): @@ -274,33 +276,73 @@ class NetworkACLAddRuleAction(AbstractAction): **kwargs, ) -> None: super().__init__(manager=manager) - num_permissions = 2 + num_permissions = 3 self.shape: Dict[str, int] = { "position": max_acl_rules, "permission": num_permissions, - "source_ip_idx": num_ips, - "dest_ip_idx": num_ips, - "source_port_idx": num_ports, - "dest_port_idx": num_ports, - "protocol_idx": num_protocols, + "source_ip_id": num_ips, + "dest_ip_id": num_ips, + "source_port_id": num_ports, + "dest_port_id": num_ports, + "protocol_id": num_protocols, } self.target_router_uuid: str = target_router_uuid def form_request( - self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx + self, position, permission, source_ip_id, dest_ip_id, source_port_id, dest_port_id, protocol_id ) -> List[str]: - protocol = self.manager.get_internet_protocol_by_idx(protocol_idx) - src_ip = self.manager.get_ip_address_by_idx(source_ip_idx) - src_port = self.manager.get_port_by_idx(source_port_idx) - dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx) - dst_port = self.manager.get_port_by_idx(dest_port_idx) + if permission == 0: + permission_str = "UNUSED" + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + elif permission == 1: + permission_str = "ALLOW" + elif permission == 2: + permission_str = "DENY" + else: + _LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.") + + if protocol_id == 0: + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + + if protocol_id == 1: + protocol = "ALL" + else: + protocol = self.manager.get_internet_protocol_by_idx(protocol_id-2) + # subtract 2 to account for UNUSED=0 and ALL=1. + + if source_ip_id in [0,1]: + src_ip = "ALL" + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + else: + src_ip = self.manager.get_ip_address_by_idx(source_ip_id-2) + # subtract 2 to account for UNUSED=0, and ALL=1 + + if source_port_id == 1: + src_port = "ALL" + else: + src_port = self.manager.get_port_by_idx(source_port_id-2) + # subtract 2 to account for UNUSED=0, and ALL=1 + + if dest_ip_id in (0,1): + dst_ip = "ALL" + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + else: + dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id) + # subtract 2 to account for UNUSED=0, and ALL=1 + + if dest_port_id == 1: + dst_port = "ALL" + else: + dst_port = self.manager.get_port_by_idx(dest_port_id) + # subtract 2 to account for UNUSED=0, and ALL=1 + return [ "network", "node", self.target_router_uuid, "acl", "add_rule", - permission, + permission_str, protocol, src_ip, src_port, @@ -320,36 +362,52 @@ class NetworkACLRemoveRuleAction(AbstractAction): return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position] -class NetworkNICEnableAction(AbstractAction): +class NetworkNICAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} + self.verb: str def form_request(self, node_id: int, nic_id: int) -> List[str]: + node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id) + nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id) + if node_uuid is None or nic_uuid is None: + return ["do_nothing"] return [ "network", "node", - self.manager.get_node_uuid_by_idx(node_idx=node_id), + node_uuid, "nic", - self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id), - "enable", + nic_uuid, + self.verb, ] -class NetworkNICDisableAction(AbstractAction): +class NetworkNICEnableAction(NetworkNICAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} + super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) + self.verb = "enable" - def form_request(self, node_id: int, nic_id: int) -> List[str]: - return [ - "network", - "node", - self.manager.get_node_uuid_by_idx(node_idx=node_id), - "nic", - self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id), - "disable", - ] + +class NetworkNICDisableAction(NetworkNICAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) + self.verb = "disable" + +# class NetworkNICDisableAction(AbstractAction): +# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: +# super().__init__(manager=manager) +# self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} + +# def form_request(self, node_id: int, nic_id: int) -> List[str]: +# return [ +# "network", +# "node", +# self.manager.get_node_uuid_by_idx(node_idx=node_id), +# "nic", +# self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id), +# "disable", +# ] class ActionManager: diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index c5b931ee..28c87af1 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING +from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING, Sequence, Tuple from gym import spaces from pydantic import BaseModel @@ -15,7 +15,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is """ -def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any: +def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: """ Access an item from a deeply dictionary with a list of keys. @@ -29,12 +29,13 @@ def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any: :return: The value in the dictionary :rtype: Any """ - if len(keys) == 0: + key_list = [*keys] # copy keys to a new list to prevent editing original list + if len(key_list) == 0: return dictionary - k = keys.pop(0) + k = key_list.pop(0) if k not in dictionary: return NOT_PRESENT_IN_STATE - return access_from_nested_dict(dictionary[k], keys) + return access_from_nested_dict(dictionary[k], key_list) class AbstractObservation(ABC): @@ -66,7 +67,7 @@ class AbstractObservation(ABC): class FileObservation(AbstractObservation): - def __init__(self, where: Optional[List[str]] = None) -> None: + def __init__(self, where: Optional[Tuple[str]] = None) -> None: """ _summary_ @@ -79,7 +80,7 @@ class FileObservation(AbstractObservation): :type where: Optional[List[str]] """ super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where self.default_observation: spaces.Space = {"health_status": 0} "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." @@ -104,7 +105,7 @@ class ServiceObservation(AbstractObservation): default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} "Default observation is what should be returned when the service doesn't exist." - def __init__(self, where: Optional[List[str]] = None) -> None: + def __init__(self, where: Optional[Tuple[str]] = None) -> None: """ :param where: Store information about where in the simulation state dictionary to find the relevant information. Optional. If None, this corresponds that the file does not exist and the observation will be populated with @@ -115,7 +116,7 @@ class ServiceObservation(AbstractObservation): :type where: Optional[List[str]] """ super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where def observe(self, state: Dict) -> Dict: if self.where is None: @@ -124,7 +125,7 @@ class ServiceObservation(AbstractObservation): service_state = access_from_nested_dict(state, self.where) if service_state is NOT_PRESENT_IN_STATE: return self.default_observation - return {"operating_status": service_state["operating_status"], "health_status": service_state["health_status"]} + return {"operating_status": service_state["operating_state"], "health_status": service_state["health_status"]} @property def space(self) -> spaces.Space: @@ -132,7 +133,9 @@ class ServiceObservation(AbstractObservation): @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]=None): - return cls(where=parent_where+["services",session.ref_map_services[config['service_ref']]]) + return cls( + where=parent_where+["services",session.ref_map_services[config['service_ref']].uuid] + ) @@ -140,7 +143,7 @@ class LinkObservation(AbstractObservation): default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}} "Default observation is what should be returned when the link doesn't exist." - def __init__(self, where: Optional[List[str]] = None) -> None: + def __init__(self, where: Optional[Tuple[str]] = None) -> None: """ :param where: Store information about where in the simulation state dictionary to find the relevant information. Optional. If None, this corresponds that the file does not exist and the observation will be populated with @@ -151,7 +154,7 @@ class LinkObservation(AbstractObservation): :type where: Optional[List[str]] """ super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where def observe(self, state: Dict) -> Dict: if self.where is None: @@ -180,7 +183,7 @@ class LinkObservation(AbstractObservation): class FolderObservation(AbstractObservation): - def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None: + def __init__(self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = []) -> None: """Initialise folder Observation, including files inside of the folder. :param where: Where in the simulation state dictionary to find the relevant information for this folder. @@ -199,7 +202,7 @@ class FolderObservation(AbstractObservation): """ super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where self.files: List[FileObservation] = files @@ -246,9 +249,9 @@ class FolderObservation(AbstractObservation): class NicObservation(AbstractObservation): default_observation: spaces.Space = {"nic_status": 0} - def __init__(self, where: Optional[List[str]] = None) -> None: + def __init__(self, where: Optional[Tuple[str]] = None) -> None: super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where def observe(self, state: Dict) -> Dict: if self.where is None: @@ -271,7 +274,7 @@ class NicObservation(AbstractObservation): class NodeObservation(AbstractObservation): def __init__( self, - where: Optional[List[str]] = None, + where: Optional[Tuple[str]] = None, services: List[ServiceObservation] = [], folders: List[FolderObservation] = [], nics: List[NicObservation] = [], @@ -298,7 +301,7 @@ class NodeObservation(AbstractObservation): :type max_nics: int, optional """ super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where self.services: List[ServiceObservation] = services self.folders: List[FolderObservation] = folders @@ -371,10 +374,10 @@ class AclObservation(AbstractObservation): # if a file is created at runtime, we have currently got no way of telling the observation space to track it. # this needs adding, but not for the MVP. def __init__( - self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[List[str]] = None, num_rules: int = 10 + self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[Tuple[str]] = None, num_rules: int = 10 ) -> None: super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where self.num_rules: int = num_rules self.node_to_id: Dict[str, int] = node_ip_to_id "List of node IP addresses, order in this list determines how they are converted to an ID" @@ -403,6 +406,8 @@ class AclObservation(AbstractObservation): if acl_state is NOT_PRESENT_IN_STATE: return self.default_observation + + #TODO: what if the ACL has more rules than num of max rules for obs space obs = {} obs["RULES"] = {} for i, rule_state in acl_state.items(): @@ -466,7 +471,7 @@ class AclObservation(AbstractObservation): node_ip_to_id=node_ip_to_idx, ports=session.options.ports, protocols=session.options.protocols, - where=["network", "nodes", router_uuid]) + where=["network", "nodes", router_uuid, "acl", "acl"]) @@ -498,7 +503,7 @@ class UC2BlueObservation(AbstractObservation): where:Optional[List[str]] = None, ) -> None: super().__init__() - self.where: Optional[List[str]] = where + self.where: Optional[Tuple[str]] = where self.nodes: List[NodeObservation] = nodes self.links: List[LinkObservation] = links @@ -517,11 +522,10 @@ class UC2BlueObservation(AbstractObservation): return self.default_observation obs = {} - obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs['ACL'] = {self.acl.observe(state)} - obs['ICS'] = {self.ics.observe(state)} + obs['ACL'] = self.acl.observe(state) + obs['ICS'] = self.ics.observe(state) return obs @@ -546,7 +550,7 @@ class UC2BlueObservation(AbstractObservation): acl = AclObservation.from_config(config=acl_config, session=session) ics_config = config["ics"] - ics = ICSObservation.from_config(ics_config) + ics = ICSObservation.from_config(config=ics_config, session=session) new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network']) return new diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 2e7681a9..3691c101 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -111,11 +111,11 @@ class AccessControlList(SimComponent): Action( func=lambda request, context: self.add_rule( ACLAction[request[0]], - IPProtocol[request[1]], - IPv4Address[request[2]], - Port[request[3]], - IPv4Address[request[4]], - Port[request[5]], + None if request[1] is "ALL" else IPProtocol[request[1]], + IPv4Address(request[2]), + None if request[3] is "ALL" else Port[request[3]], + IPv4Address(request[4]), + None if request[5] is "ALL" else Port[request[5]], int(request[6]), ) ), diff --git a/src/primaite/simulator/network/hardware/nodes/switch.py b/src/primaite/simulator/network/hardware/nodes/switch.py index ac8dabd1..bb296203 100644 --- a/src/primaite/simulator/network/hardware/nodes/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/switch.py @@ -55,12 +55,11 @@ class Switch(Node): :return: Current state of this object and child objects. """ - return { - "uuid": self.uuid, - "num_ports": self.num_ports, # redundant? - "ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}, - "mac_address_table": {mac: port for mac, port in self.mac_address_table.items()}, - } + state = super().describe_state() + state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()} + state["num_ports"]= self.num_ports # redundant? + state["mac_address_table"]= {mac: port for mac, port in self.mac_address_table.items()} + return state def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): """