From 91f06c15f6d52ef821dde93f5ad7ec7d23f95072 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 9 Oct 2023 18:35:30 +0100 Subject: [PATCH] Fix formatting with precommit --- src/primaite/game/agent/actions.py | 67 +++---- src/primaite/game/agent/interface.py | 6 +- src/primaite/game/agent/observations.py | 185 ++++++++++-------- src/primaite/game/agent/rewards.py | 26 +-- src/primaite/game/session.py | 61 +++--- .../network/hardware/nodes/switch.py | 4 +- 6 files changed, 189 insertions(+), 160 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 1e6893ff..cba90305 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -4,8 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gym import spaces -from primaite.simulator.sim_container import Simulation from primaite import getLogger +from primaite.simulator.sim_container import Simulation + _LOGGER = getLogger(__name__) if TYPE_CHECKING: @@ -90,56 +91,56 @@ class NodeServiceAbstractAction(AbstractAction): class NodeServiceScanAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "scan" class NodeServiceStopAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "stop" class NodeServiceStartAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "start" class NodeServicePauseAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "pause" class NodeServiceResumeAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "resume" class NodeServiceRestartAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "restart" class NodeServiceDisableAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "disable" class NodeServiceEnableAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "enable" class NodeFolderAbstractAction(AbstractAction): @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} self.verb: str @@ -153,25 +154,25 @@ class NodeFolderAbstractAction(AbstractAction): class NodeFolderScanAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "scan" class NodeFolderCheckhashAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "checkhash" class NodeFolderRepairAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "repair" class NodeFolderRestoreAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "restore" @@ -293,7 +294,7 @@ class NetworkACLAddRuleAction(AbstractAction): ) -> List[str]: if permission == 0: permission_str = "UNUSED" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS elif permission == 1: permission_str = "ALLOW" elif permission == 2: @@ -302,30 +303,30 @@ class NetworkACLAddRuleAction(AbstractAction): _LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.") if protocol_id == 0: - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS if protocol_id == 1: protocol = "ALL" else: - protocol = self.manager.get_internet_protocol_by_idx(protocol_id-2) + protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2) # subtract 2 to account for UNUSED=0 and ALL=1. - if source_ip_id in [0,1]: + if source_ip_id in [0, 1]: src_ip = "ALL" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS else: - src_ip = self.manager.get_ip_address_by_idx(source_ip_id-2) + src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 if source_port_id == 1: src_port = "ALL" else: - src_port = self.manager.get_port_by_idx(source_port_id-2) + src_port = self.manager.get_port_by_idx(source_port_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 - if dest_ip_id in (0,1): + if dest_ip_id in (0, 1): dst_ip = "ALL" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS + return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS else: dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id) # subtract 2 to account for UNUSED=0, and ALL=1 @@ -394,6 +395,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction): super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) self.verb = "disable" + # class NetworkNICDisableAction(AbstractAction): # def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: # super().__init__(manager=manager) @@ -495,7 +497,7 @@ class ActionManager: "num_protocols": len(self.protocols), "num_ports": len(self.protocols), "num_ips": len(self.ip_address_list), - "max_acl_rules":max_acl_rules, + "max_acl_rules": max_acl_rules, "max_nics_per_node": max_nics_per_node, } self.actions: Dict[str, AbstractAction] = {} @@ -507,8 +509,8 @@ class ActionManager: # option_2: value2 # where `type` decides which AbstractAction subclass should be used # and `options` is an optional dict of options to pass to the init method of the action class - act_type = act_spec.get('type') - act_options = act_spec.get('options', {}) + act_type = act_spec.get("type") + act_options = act_spec.get("options", {}) self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options) self.action_map: Dict[int, Tuple[str, Dict]] = {} @@ -555,13 +557,12 @@ class ActionManager: param_combinations = list(itertools.product(*possibilities)) all_action_possibilities.extend( [ - ( - act_name, {param_names[i]:param_combinations[j][i] for i in range(len(param_names))} - ) for j in range(len(param_combinations))] - ) - - return {i:p for i,p in enumerate(all_action_possibilities)} + (act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))}) + for j in range(len(param_combinations)) + ] + ) + return {i: p for i, p in enumerate(all_action_possibilities)} def get_action(self, action: int) -> Tuple[str, Dict]: """Produce action in CAOS format""" @@ -627,7 +628,7 @@ class ActionManager: session=session, actions=cfg["action_list"], # node_uuids=cfg["options"]["node_uuids"], - **cfg['options'], + **cfg["options"], protocols=session.options.protocols, ports=session.options.ports, ip_address_list=None, diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 6083db6f..817e59b1 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -23,7 +23,7 @@ class AbstractAgent(ABC): observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], ) -> None: - self.agent_name:str = agent_name or "unnamed_agent" + self.agent_name: str = agent_name or "unnamed_agent" self.action_space: Optional[ActionManager] = action_space self.observation_space: Optional[ObservationSpace] = observation_space self.reward_function: Optional[RewardFunction] = reward_function @@ -46,9 +46,9 @@ class AbstractAgent(ABC): def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action - return ("DO_NOTHING", {} ) + return ("DO_NOTHING", {}) - def format_request(self, action:Tuple[str,Dict], options:Dict[str, int]) -> List[str]: + def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]: # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 28c87af1..7b10f957 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING, Sequence, Tuple +from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, TYPE_CHECKING from gym import spaces from pydantic import BaseModel from primaite.simulator.sim_container import Simulation + if TYPE_CHECKING: from primaite.game.session import PrimaiteSession @@ -29,7 +30,7 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: :return: The value in the dictionary :rtype: Any """ - key_list = [*keys] # copy keys to a new list to prevent editing original list + key_list = [*keys] # copy keys to a new list to prevent editing original list if len(key_list) == 0: return dictionary k = key_list.pop(0) @@ -58,7 +59,7 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config:Dict, session:"PrimaiteSession"): + def from_config(cls, config: Dict, session: "PrimaiteSession"): """Create this observation space component form a serialised format. The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, @@ -98,7 +99,7 @@ class FileObservation(AbstractObservation): @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where=None): - return cls(where=parent_where+["files", config["file_name"]]) + return cls(where=parent_where + ["files", config["file_name"]]) class ServiceObservation(AbstractObservation): @@ -132,11 +133,8 @@ class ServiceObservation(AbstractObservation): return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]=None): - return cls( - where=parent_where+["services",session.ref_map_services[config['service_ref']].uuid] - ) - + def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None): + return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid]) class LinkObservation(AbstractObservation): @@ -179,7 +177,7 @@ class LinkObservation(AbstractObservation): @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession"): - return cls(where=['network','links', session.ref_map_links[config['link_ref']]]) + return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]]) class FolderObservation(AbstractObservation): @@ -237,13 +235,13 @@ class FolderObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]): - where = parent_where + ["folders", config['folder_name']] + def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]): + where = parent_where + ["folders", config["folder_name"]] file_configs = config["files"] files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs] - return cls(where=where,files=files) + return cls(where=where, files=files) class NicObservation(AbstractObservation): @@ -267,7 +265,7 @@ class NicObservation(AbstractObservation): return spaces.Dict({"nic_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]): + def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]): return cls(where=parent_where + ["NICs", config["nic_uuid"]]) @@ -278,7 +276,7 @@ class NodeObservation(AbstractObservation): services: List[ServiceObservation] = [], folders: List[FolderObservation] = [], nics: List[NicObservation] = [], - logon_status:bool=False + logon_status: bool = False, ) -> None: """ Configurable observation for a node in the simulation. @@ -306,7 +304,7 @@ class NodeObservation(AbstractObservation): self.services: List[ServiceObservation] = services self.folders: List[FolderObservation] = folders self.nics: List[NicObservation] = nics - self.logon_status:bool=logon_status + self.logon_status: bool = logon_status self.default_observation: Dict = { "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, @@ -315,7 +313,7 @@ class NodeObservation(AbstractObservation): "operating_status": 0, } if self.logon_status: - self.default_observation['logon_status']=0 + self.default_observation["logon_status"] = 0 def observe(self, state: Dict) -> Dict: if self.where is None: @@ -332,7 +330,7 @@ class NodeObservation(AbstractObservation): obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} if self.logon_status: - obs['logon_status'] = 0 + obs["logon_status"] = 0 return obs @@ -345,26 +343,28 @@ class NodeObservation(AbstractObservation): "NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}), } if self.logon_status: - space_shape['logon_status'] = spaces.Discrete(3) + space_shape["logon_status"] = spaces.Discrete(3) return spaces.Dict(space_shape) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]= None) -> "NodeObservation": - node_uuid = session.ref_map_nodes[config['node_ref']] + def from_config( + cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None + ) -> "NodeObservation": + node_uuid = session.ref_map_nodes[config["node_ref"]] if parent_where is None: where = ["network", "nodes", node_uuid] else: where = parent_where + ["nodes", node_uuid] - svc_configs = config.get('services', {}) + svc_configs = config.get("services", {}) services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs] - folder_configs = config.get('folders', {}) - folders = [FolderObservation.from_config(config=c,session=session, parent_where=where) for c in folder_configs] + folder_configs = config.get("folders", {}) + folders = [FolderObservation.from_config(config=c, session=session, parent_where=where) for c in folder_configs] nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys() - nic_configs = [{'nic_uuid':n for n in nic_uuids }] if nic_uuids else [] + nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else [] nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs] - logon_status = config.get('logon_status',False) + logon_status = config.get("logon_status", False) return cls(where=where, services=services, folders=folders, nics=nics, logon_status=logon_status) @@ -374,7 +374,12 @@ class AclObservation(AbstractObservation): # if a file is created at runtime, we have currently got no way of telling the observation space to track it. # this needs adding, but not for the MVP. def __init__( - self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[Tuple[str]] = None, num_rules: int = 10 + self, + node_ip_to_id: Dict[str, int], + ports: List[int], + protocols: list[str], + where: Optional[Tuple[str]] = None, + num_rules: int = 10, ) -> None: super().__init__() self.where: Optional[Tuple[str]] = where @@ -386,16 +391,18 @@ class AclObservation(AbstractObservation): self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} "List of protocols which are part of the game, defines ordering when converting to an ID" self.default_observation: Dict = { - "RULES": {i+ 1:{ - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, + "RULES": { + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, } - for i in range(self.num_rules) + for i in range(self.num_rules) } } @@ -406,8 +413,7 @@ class AclObservation(AbstractObservation): if acl_state is NOT_PRESENT_IN_STATE: return self.default_observation - - #TODO: what if the ACL has more rules than num of max rules for obs space + # TODO: what if the ACL has more rules than num of max rules for obs space obs = {} obs["RULES"] = {} for i, rule_state in acl_state.items(): @@ -439,7 +445,8 @@ class AclObservation(AbstractObservation): { "RULE": spaces.Dict( { - i + 1: spaces.Dict( + i + + 1: spaces.Dict( { "position": spaces.Discrete(self.num_rules), "permission": spaces.Discrete(3), @@ -460,23 +467,23 @@ class AclObservation(AbstractObservation): @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation": node_ip_to_idx = {} - for node_idx, node_cfg in enumerate(config['node_order']): + for node_idx, node_cfg in enumerate(config["node_order"]): n_ref = node_cfg["node_ref"] n_obj = session.simulation.network.nodes[session.ref_map_nodes[n_ref]] for nic_uuid, nic_obj in n_obj.nics.items(): node_ip_to_idx[nic_obj.ip_address] = node_idx + 2 - router_uuid = session.ref_map_nodes[config['router_node_ref']] + router_uuid = session.ref_map_nodes[config["router_node_ref"]] return cls( node_ip_to_id=node_ip_to_idx, ports=session.options.ports, protocols=session.options.protocols, - where=["network", "nodes", router_uuid, "acl", "acl"]) - + where=["network", "nodes", router_uuid, "acl", "acl"], + ) class NullObservation(AbstractObservation): - def __init__(self, where:Optional[List[str]]=None): + def __init__(self, where: Optional[List[str]] = None): self.default_observation: Dict = {} def observe(self, state: Dict) -> Dict: @@ -487,20 +494,22 @@ class NullObservation(AbstractObservation): return spaces.Dict({}) @classmethod - def from_config(cls, config:Dict, session:Optional["PrimaiteSession"]=None) -> "NullObservation": + def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation": return cls() -class ICSObservation(NullObservation): pass + +class ICSObservation(NullObservation): + pass class UC2BlueObservation(AbstractObservation): def __init__( - self, - nodes: List[NodeObservation], - links: List[LinkObservation], - acl: AclObservation, - ics: ICSObservation, - where:Optional[List[str]] = None, + self, + nodes: List[NodeObservation], + links: List[LinkObservation], + acl: AclObservation, + ics: ICSObservation, + where: Optional[List[str]] = None, ) -> None: super().__init__() self.where: Optional[Tuple[str]] = where @@ -510,36 +519,38 @@ class UC2BlueObservation(AbstractObservation): self.acl: AclObservation = acl self.ics: ICSObservation = ics - self.default_observation : Dict = { - "NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)}, - "LINKS": {i+1: l.default_observation for i,l in enumerate(self.links)}, + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, "ACL": self.acl.default_observation, "ICS": self.ics.default_observation, } - def observe(self, state:Dict) -> Dict: + def observe(self, state: Dict) -> Dict: if self.where is None: return self.default_observation obs = {} - obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs['ACL'] = self.acl.observe(state) - obs['ICS'] = self.ics.observe(state) + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} + obs["ACL"] = self.acl.observe(state) + obs["ICS"] = self.ics.observe(state) return obs @property def space(self) -> spaces.Space: - return spaces.Dict({ - "NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}), - "LINKS": spaces.Dict({i+1: link.space for i, link in enumerate(self.links)}), - "ACL": self.acl.space, - "ICS": self.ics.space, - }) + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), + "ACL": self.acl.space, + "ICS": self.ics.space, + } + ) @classmethod - def from_config(cls, config:Dict, session:"PrimaiteSession"): + def from_config(cls, config: Dict, session: "PrimaiteSession"): node_configs = config["nodes"] nodes = [NodeObservation.from_config(config=n, session=session) for n in node_configs] @@ -551,18 +562,18 @@ class UC2BlueObservation(AbstractObservation): ics_config = config["ics"] ics = ICSObservation.from_config(config=ics_config, session=session) - new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network']) + new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) return new class UC2RedObservation(AbstractObservation): - def __init__(self, nodes:List[NodeObservation], where:Optional[List[str]] = None) -> None: + def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: super().__init__() - self.where:Optional[List[str]] = where + self.where: Optional[List[str]] = where self.nodes: List[NodeObservation] = nodes - self.default_observation : Dict = { - "NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)}, + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, } def observe(self, state: Dict) -> Dict: @@ -570,14 +581,16 @@ class UC2RedObservation(AbstractObservation): return self.default_observation obs = {} - obs['NODES'] = {i+1: node.observe(state) for i, node in enumerate(self.nodes)} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} return obs @property def space(self) -> spaces.Space: - return spaces.Dict({ - "NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}), - }) + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + } + ) @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession"): @@ -586,7 +599,9 @@ class UC2RedObservation(AbstractObservation): return cls(nodes=nodes, where=["network"]) -class UC2GreenObservation(NullObservation): pass +class UC2GreenObservation(NullObservation): + pass + class ObservationSpace: """ @@ -603,7 +618,7 @@ class ObservationSpace: # what this class does: # keep a list of observations # create observations for an actor from the config - def __init__(self, observation:AbstractObservation) -> None: + def __init__(self, observation: AbstractObservation) -> None: self.obs: AbstractObservation = observation def observe(self, state) -> Dict: @@ -614,12 +629,12 @@ class ObservationSpace: return self.obs.space @classmethod - def from_config(cls, config:Dict, session:"PrimaiteSession") -> "ObservationSpace": - if config['type'] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get('options',{}), session=session)) - elif config['type'] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get('options',{}), session=session)) - elif config['type'] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options",{}), session=session)) + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace": + if config["type"] == "UC2BlueObservation": + return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session)) + elif config["type"] == "UC2RedObservation": + return cls(UC2RedObservation.from_config(config.get("options", {}), session=session)) + elif config["type"] == "UC2GreenObservation": + return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session)) else: raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index a4ceb2dd..18925edc 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -1,34 +1,34 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -class AbstractReward(): + +class AbstractReward: def __init__(self): ... @abstractmethod - def calculate(self, state:Dict) -> float: + def calculate(self, state: Dict) -> float: return 0.3 -class DummyReward(AbstractReward): +class DummyReward(AbstractReward): def calculate(self, state: Dict) -> float: return -0.1 -class RewardFunction(): - __rew_class_identifiers:Dict[str,type[AbstractReward]] = { - "DUMMY" : DummyReward - } - def __init__(self, reward_function:AbstractReward): + +class RewardFunction: + __rew_class_identifiers: Dict[str, type[AbstractReward]] = {"DUMMY": DummyReward} + + def __init__(self, reward_function: AbstractReward): self.reward: AbstractReward = reward_function - def calculate(self, state:Dict) -> float: + def calculate(self, state: Dict) -> float: return self.reward.calculate(state) @classmethod - def from_config(cls, cfg:Dict) -> "RewardFunction": - for rew_component_cfg in cfg['reward_components']: - rew_type = rew_component_cfg['type'] + def from_config(cls, cfg: Dict) -> "RewardFunction": + for rew_component_cfg in cfg["reward_components"]: + rew_type = rew_component_cfg["type"] rew_component = cls.__rew_class_identifiers[rew_type]() new = cls(reward_function=rew_component) return new - diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 4bcf26e4..7b2225ef 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,6 +10,7 @@ from typing import Dict, List from pydantic import BaseModel +from primaite import getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, RandomAgent from primaite.game.agent.observations import ( @@ -37,14 +38,12 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient -from primaite.simulator.system.services.database_service import DatabaseService -from primaite.simulator.system.services.dns_client import DNSClient -from primaite.simulator.system.services.dns_server import DNSServer +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.service import Service -from primaite import getLogger - _LOGGER = getLogger(__name__) @@ -91,7 +90,7 @@ class PrimaiteSession: agent_action, action_options = agent.get_action(agent_obs, agent_reward) # 9. CAOS action is converted into request (extra information might be needed to enrich # the request, this is what the execution definition is there for) - _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements + _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements agent_request = agent.format_request(agent_action, action_options) # 10. primaite session receives the action from the agents and asks the simulation to apply each @@ -106,8 +105,8 @@ class PrimaiteSession: def from_config(cls, cfg: dict) -> "PrimaiteSession": sess = cls() sess.options = PrimaiteSessionOptions( - ports = cfg['game_config']['ports'], - protocols = cfg['game_config']['protocols'], + ports=cfg["game_config"]["ports"], + protocols=cfg["game_config"]["protocols"], ) sim = sess.simulation net = sim.network @@ -230,7 +229,7 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space=ObservationSpace.from_config(observation_space_cfg, sess) + obs_space = ObservationSpace.from_config(observation_space_cfg, sess) """ # if observation_space_cfg is None: @@ -331,23 +330,23 @@ class PrimaiteSession: """ # CREATE ACTION SPACE - action_space_cfg['options']['node_uuids'] = [] + action_space_cfg["options"]["node_uuids"] = [] # if a list of nodes is defined, convert them from node references to node UUIDs - for action_node_option in action_space_cfg.get('options',{}).pop('nodes', {}): - if 'node_ref' in action_node_option: - node_uuid = sess.ref_map_nodes[action_node_option['node_ref']] - action_space_cfg['options']['node_uuids'].append(node_uuid) + for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}): + if "node_ref" in action_node_option: + node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]] + action_space_cfg["options"]["node_uuids"].append(node_uuid) # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. # However, it's not possible to specify the node uuids directly in the config, as they are generated # dynamically, so we have to translate node references to uuids before passing this config on. - if 'action_list' in action_space_cfg: - for action_config in action_space_cfg['action_list']: - if 'options' in action_config: - if 'target_router_ref' in action_config['options']: - _target = action_config['options']['target_router_ref'] - action_config['options']['target_router_uuid'] = sess.ref_map_nodes[_target] + if "action_list" in action_space_cfg: + for action_config in action_space_cfg["action_list"]: + if "options" in action_config: + if "target_router_ref" in action_config["options"]: + _target = action_config["options"]["target_router_ref"] + action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target] action_space = ActionManager.from_config(sess, action_space_cfg) @@ -357,16 +356,30 @@ class PrimaiteSession: # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": # TODO: implement non-random agents and fix this parsing - new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function) + new_agent = RandomAgent( + agent_name=agent_cfg["ref"], + action_space=action_space, + observation_space=obs_space, + reward_function=rew_function, + ) sess.agents.append(new_agent) elif agent_type == "GATERLAgent": - new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function) + new_agent = RandomAgent( + agent_name=agent_cfg["ref"], + action_space=action_space, + observation_space=obs_space, + reward_function=rew_function, + ) sess.agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": - new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function) + new_agent = RandomAgent( + agent_name=agent_cfg["ref"], + action_space=action_space, + observation_space=obs_space, + reward_function=rew_function, + ) sess.agents.append(new_agent) else: print("agent type not found") - return sess diff --git a/src/primaite/simulator/network/hardware/nodes/switch.py b/src/primaite/simulator/network/hardware/nodes/switch.py index bb296203..09b53483 100644 --- a/src/primaite/simulator/network/hardware/nodes/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/switch.py @@ -57,8 +57,8 @@ class Switch(Node): """ state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()} - state["num_ports"]= self.num_ports # redundant? - state["mac_address_table"]= {mac: port for mac, port in self.mac_address_table.items()} + state["num_ports"] = self.num_ports # redundant? + state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()} return state def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):