diff --git a/example_config.yaml b/example_config.yaml index 8cf401cc..b47355c3 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -131,32 +131,36 @@ game_config: action_space: action_list: - - DONOTHING - - NODE_SERVICE_SCAN - - NODE_SERVICE_STOP - # - NODE_SERVICE_START - # - NODE_SERVICE_PAUSE - # - NODE_SERVICE_RESUME - # - NODE_SERVICE_RESTART - # - NODE_SERVICE_DISABLE - # - NODE_SERVICE_ENABLE - # - NODE_FILE_SCAN - # - NODE_FILE_CHECKHASH - # - NODE_FILE_DELETE - # - NODE_FILE_REPAIR - # - NODE_FILE_RESTORE - # - NODE_FOLDER_SCAN - # - NODE_FOLDER_CHECKHASH - # - NODE_FOLDER_REPAIR - # - NODE_FOLDER_RESTORE - # - NODE_OS_SCAN - # - NODE_SHUTDOWN - # - NODE_STARTUP - # - NODE_RESET - # - NETWORK_ACL_ADDRULE - # - NETWORK_ACL_REMOVERULE - # - NETWORK_NIC_ENABLE - - NETWORK_NIC_DISABLE + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: NETWORK_ACL_ADDRULE + options: + target_router_ref: router_1 + - type: NETWORK_ACL_REMOVERULE + options: + target_router_ref: router_1 + - type: NETWORK_NIC_ENABLE + - type: NETWORK_NIC_DISABLE action_map: 0: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 6c4ae3b2..f6f96161 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1,12 +1,13 @@ +import itertools from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple -import itertools - - -from 
primaite.simulator.sim_container import Simulation from gym import spaces +from primaite.game.session import PrimaiteSession +from primaite.simulator.sim_container import Simulation + + class ExecutionDefiniton(ABC): """ Converter from actions to simulator requests. @@ -23,9 +24,8 @@ class ExecutionDefiniton(ABC): class AbstractAction(ABC): - @abstractmethod - def __init__(self, manager:"ActionManager", **kwargs) -> None: + def __init__(self, manager: "ActionManager", **kwargs) -> None: """ Init method for action. @@ -35,13 +35,12 @@ class AbstractAction(ABC): per node), we need to pass those options to every action that gets created. To pervent verbosity, these parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply. """ - self.name:str = "" + self.name: str = "" """Human-readable action identifier used for printing, logging, and reporting.""" - self.shape = (0,) - """Tuple describing number of options for each parameter of this action. Can be passed to - gym.spaces.MultiDiscrete to form a valid space.""" - self.manager:ActionManager = manager - + self.shape: Dict[str, int] = {} + """Dictionary describing the number of options for each parameter of this action. The keys of this dict must + align with the keyword args of the form_request method.""" + self.manager: ActionManager = manager @abstractmethod def form_request(self) -> List[str]: @@ -50,14 +49,20 @@ class AbstractAction(ABC): class DoNothingAction(AbstractAction): - def __init__(self, manager:"ActionManager", **kwargs) -> None: + def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) self.name = "DONOTHING" - self.shape = (1,) + self.shape: Dict[str, int] = { + "dummy": 1, + } + # This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1), + # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter + # with one option. 
This just aids the Action Manager to enumerate all possibilities. def form_request(self) -> List[str]: return ["do_nothing"] + class NodeServiceAbstractAction(AbstractAction): """ Base class for service actions. @@ -65,211 +70,284 @@ class NodeServiceAbstractAction(AbstractAction): Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from this base class. """ - @abstractmethod - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Tuple[int] = (num_nodes, num_services) - self.verb:str - def form_request(self, node_id:int, service_id:int) -> List[str]: + @abstractmethod + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} + self.verb: str + + def form_request(self, node_id: int, service_id: int) -> List[str]: node_uuid = self.manager.get_node_uuid_by_idx(node_id) service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id) if node_uuid is None or service_uuid is None: return ["do_nothing"] - return ['network', 'node', node_uuid, 'services', service_uuid, self.verb] + return ["network", "node", node_uuid, "services", service_uuid, self.verb] + class NodeServiceScanAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "scan" + class NodeServiceStopAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "stop" + class 
NodeServiceStartAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "start" + class NodeServicePauseAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "pause" + class NodeServiceResumeAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "resume" + class NodeServiceRestartAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "restart" + class NodeServiceDisableAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "disable" + class NodeServiceEnableAction(NodeServiceAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: super().__init__(manager=manager) self.verb = "enable" - class NodeFolderAbstractAction(AbstractAction): @abstractmethod - def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_folders, 
**kwargs) -> None: super().__init__(manager=manager) - self.shape = (num_nodes, num_folders) + self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} self.verb: str - def form_request(self, node_id:int, folder_id:int) -> List[str]: + def form_request(self, node_id: int, folder_id: int) -> List[str]: node_uuid = self.manager.get_node_uuid_by_idx(node_id) folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) if node_uuid is None or folder_uuid is None: return ["do_nothing"] - return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, self.verb] + return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, self.verb] + class NodeFolderScanAction(NodeFolderAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, **kwargs) - self.verb:str = "scan" + self.verb: str = "scan" + class NodeFolderCheckhashAction(NodeFolderAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, **kwargs) - self.verb:str = "checkhash" + self.verb: str = "checkhash" + class NodeFolderRepairAction(NodeFolderAbstractAction): - def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, **kwargs) - self.verb:str = "repair" + self.verb: str = "repair" + class NodeFolderRestoreAction(NodeFolderAbstractAction): def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, **kwargs) - self.verb:str = 
"restore" + self.verb: str = "restore" class NodeFileAbstractAction(AbstractAction): @abstractmethod - def __init__(self, manager:"ActionManager", num_nodes:int, num_folders:int, num_files:int, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager=manager) - self.shape:Tuple[int] = (num_nodes, num_folders, num_files) - self.verb:str + self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} + self.verb: str - def form_request(self, node_id:int, folder_id:int, file_id:int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: node_uuid = self.manager.get_node_uuid_by_idx(node_id) folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) if node_uuid is None or folder_uuid is None or file_uuid is None: return ["do_nothing"] - return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, 'files', file_uuid, self.verb] + return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, "files", file_uuid, self.verb] + class NodeFileScanAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) self.verb = "scan" + class NodeFileCheckhashAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) self.verb = "checkhash" + class NodeFileDeleteAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, 
num_nodes, num_folders, num_files, **kwargs) self.verb = "delete" + class NodeFileRepairAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) self.verb = "repair" + class NodeFileRestoreAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) self.verb = "restore" + class NodeAbstractAction(AbstractAction): @abstractmethod def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) - self.shape: Tuple[int] = (num_nodes,) + self.shape: Dict[str, int] = {"node_id": num_nodes} self.verb: str - def form_request(self, node_id:int) -> List[str]: + def form_request(self, node_id: int) -> List[str]: node_uuid = self.manager.get_node_uuid_by_idx(node_id) return ["network", "node", node_uuid, self.verb] + class NodeOSScanAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) - self.verb = 'scan' + self.verb = "scan" + class NodeShutdownAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) - self.verb = 'shutdown' + self.verb = "shutdown" + class NodeStartupAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) - self.verb = 'start' + self.verb = "start" + class NodeResetAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) - self.verb = 'reset' + self.verb = "reset" + class NetworkACLAddRuleAction(AbstractAction): - def __init__(self, - manager: "ActionManager", 
- target_router_uuid:str, - max_acl_rules:int, - num_ips:int, - num_ports:int, - num_protocols:int, - **kwargs) -> None: + def __init__( + self, + manager: "ActionManager", + target_router_uuid: str, + max_acl_rules: int, + num_ips: int, + num_ports: int, + num_protocols: int, + **kwargs, + ) -> None: super().__init__(manager=manager) num_permissions = 2 - self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_ips, num_ips, num_ports, num_ports, num_protocols) - self.target_router_uuid:str = target_router_uuid + self.shape: Dict[str, int] = { + "position": max_acl_rules, + "permission": num_permissions, + "source_ip_idx": num_ips, + "dest_ip_idx": num_ips, + "source_port_idx": num_ports, + "dest_port_idx": num_ports, + "protocol_idx": num_protocols, + } + self.target_router_uuid: str = target_router_uuid - def form_request(self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx) -> List[str]: + def form_request( + self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx + ) -> List[str]: protocol = self.manager.get_internet_protocol_by_idx(protocol_idx) src_ip = self.manager.get_ip_address_by_idx(source_ip_idx) src_port = self.manager.get_port_by_idx(source_port_idx) dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx) dst_port = self.manager.get_port_by_idx(dest_port_idx) return [ - 'network', - 'node', + "network", + "node", self.target_router_uuid, - 'acl', - 'add_rule', + "acl", + "add_rule", permission, protocol, src_ip, src_port, dst_ip, dst_port, - position + position, ] +class NetworkACLRemoveRuleAction(AbstractAction): + def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"position": max_acl_rules} + self.target_router_uuid: str = target_router_uuid + def form_request(self, position: int) -> List[str]: + return ["network", "node", 
self.target_router_uuid, "acl", "remove_rule", position] + + +class NetworkNICEnableAction(AbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} + + def form_request(self, node_id: int, nic_id: int) -> List[str]: + return [ + "network", + "node", + self.manager.get_node_uuid_by_idx(node_idx=node_id), + "nic", + self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id), + "enable", + ] + + +class NetworkNICDisableAction(AbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} + + def form_request(self, node_id: int, nic_id: int) -> List[str]: + return [ + "network", + "node", + self.manager.get_node_uuid_by_idx(node_idx=node_id), + "nic", + self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id), + "disable", + ] class ActionManager: # let the action manager handle the conversion of action spaces into a single discrete integer space. # - # when action space is created, it will take subspaces and generate an action map by enumerating all possibilities, # BUT, the action map can be provided in the config, in which case it will use that. @@ -278,69 +356,83 @@ class ActionManager: # 0: DONOTHING # 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0 # 2: ...... 
- __act_class_identifiers:Dict[str,type] = { + __act_class_identifiers: Dict[str, type] = { "DONOTHING": DoNothingAction, "NODE_SERVICE_SCAN": NodeServiceScanAction, "NODE_SERVICE_STOP": NodeServiceStopAction, - # "NODE_SERVICE_START": NodeServiceStartAction, - # "NODE_SERVICE_PAUSE": NodeServicePauseAction, - # "NODE_SERVICE_RESUME": NodeServiceResumeAction, - # "NODE_SERVICE_RESTART": NodeServiceRestartAction, - # "NODE_SERVICE_DISABLE": NodeServiceDisableAction, - # "NODE_SERVICE_ENABLE": NodeServiceEnableAction, - # "NODE_FILE_SCAN": NodeFileScanAction, - # "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, - # "NODE_FILE_DELETE": NodeFileDeleteAction, - # "NODE_FILE_REPAIR": NodeFileRepairAction, - # "NODE_FILE_RESTORE": NodeFileRestoreAction, + "NODE_SERVICE_START": NodeServiceStartAction, + "NODE_SERVICE_PAUSE": NodeServicePauseAction, + "NODE_SERVICE_RESUME": NodeServiceResumeAction, + "NODE_SERVICE_RESTART": NodeServiceRestartAction, + "NODE_SERVICE_DISABLE": NodeServiceDisableAction, + "NODE_SERVICE_ENABLE": NodeServiceEnableAction, + "NODE_FILE_SCAN": NodeFileScanAction, + "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, + "NODE_FILE_DELETE": NodeFileDeleteAction, + "NODE_FILE_REPAIR": NodeFileRepairAction, + "NODE_FILE_RESTORE": NodeFileRestoreAction, "NODE_FOLDER_SCAN": NodeFolderScanAction, - # "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction, - # "NODE_FOLDER_REPAIR": NodeFolderRepairAction, - # "NODE_FOLDER_RESTORE": NodeFolderRestoreAction, - # "NODE_OS_SCAN": NodeOSScanAction, - # "NODE_SHUTDOWN": NodeShutdownAction, - # "NODE_STARTUP": NodeStartupAction, - # "NODE_RESET": NodeResetAction, - # "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction, - # "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction, - # "NETWORK_NIC_ENABLE": NetworkNICEnable, - # "NETWORK_NIC_DISABLE": NetworkNICDisable, + "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction, + "NODE_FOLDER_REPAIR": NodeFolderRepairAction, + "NODE_FOLDER_RESTORE": NodeFolderRestoreAction, + 
"NODE_OS_SCAN": NodeOSScanAction, + "NODE_SHUTDOWN": NodeShutdownAction, + "NODE_STARTUP": NodeStartupAction, + "NODE_RESET": NodeResetAction, + "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction, + "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction, + "NETWORK_NIC_ENABLE": NetworkNICEnableAction, + "NETWORK_NIC_DISABLE": NetworkNICDisableAction, } + def __init__( + self, + session: PrimaiteSession, # reference to session for looking up stuff + actions: List[str], # stores list of actions available to agent + node_uuids: List[str], # allows mapping index to node + max_folders_per_node: int = 2, # allows calculating shape + max_files_per_folder: int = 2, # allows calculating shape + max_services_per_node: int = 2, # allows calculating shape + max_nics_per_node: int = 8, # allows calculating shape + max_acl_rules: int = 10, # allows calculating shape + protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol + ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port + ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. 
+ act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions + ) -> None: + self.session: PrimaiteSession = session + self.sim: Simulation = self.session.simulation + self.node_uuids: List[str] = node_uuids + self.protocols: List[str] = protocols + self.ports: List[str] = ports - def __init__(self, - sim:Simulation, - actions:List[str], - node_uuids:List[str], - max_folders_per_node:int = 2, - max_files_per_folder:int = 2, - max_services_per_node:int = 2, - max_nics_per_node:int=8, - max_acl_rules:int=10, - protocols:List[str]=['TCP','UDP','ICMP'], - ports:List[str]=['HTTP','DNS','ARP','FTP'], - ip_address_list:Optional[List[str]]=None, - act_map:Optional[Dict[int, Dict]]=None) -> None: - self.sim: Simulation = sim - self.node_uuids:List[str] = node_uuids - self.protocols:List[str] = protocols - self.ports:List[str] = ports + self.ip_address_list: List[str] + if ip_address_list is not None: + self.ip_address_list = ip_address_list + else: + self.ip_address_list = [] + for node_uuid in self.node_uuids: + node_obj = self.sim.network.nodes[node_uuid] + nics = node_obj.nics + for nic_uuid, nic_obj in nics.items(): + self.ip_address_list.append(nic_obj.ip_address) action_args = { "num_nodes": len(node_uuids), - "num_folders":max_folders_per_node, + "num_folders": max_folders_per_node, "num_files": max_files_per_folder, "num_services": max_services_per_node, "num_nics": max_nics_per_node, "num_acl_rules": max_acl_rules, "num_protocols": len(self.protocols), "num_ports": len(self.protocols), - "num_ips":} + "num_ips": len(self.ip_address_list), + } self.actions: Dict[str, AbstractAction] = {} for act_type in actions: self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args) - self.action_map:Dict[int, Tuple[str, Dict]] = {} + self.action_map: Dict[int, Tuple[str, Dict]] = {} """ Action mapping that converts an integer to a specific action and parameter choice. 
@@ -350,21 +442,30 @@ class ActionManager: if act_map is None: self.action_map = self._enumerate_actions() else: - self.action_map = {i:(a['action'], a['options']) for i,a in act_map.items()} + self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()} # make sure all numbers between 0 and N are represented as dict keys in action map assert all([i in self.action_map.keys() for i in range(len(self.action_map))]) - def _enumerate_actions(self,) -> Dict[int, Tuple[AbstractAction, Dict]]: - ... + def _enumerate_actions( + self, + ) -> Dict[int, Tuple[AbstractAction, Dict]]: + all_action_possibilities = [] + for action in self.actions.values(): + param_names = (list(action.shape.keys()),) + num_possibilities = list(action.shape.values()) + possibilities = [range(n) for n in num_possibilities] - def get_action(self, action: int) -> Tuple[str,Dict]: + itertools.product(action.shape.values()) + all_action_possibilities.append((action, {})) + + def get_action(self, action: int) -> Tuple[str, Dict]: """Produce action in CAOS format""" """the agent chooses an action (as an integer), this is converted into an action in CAOS format""" """The caos format is basically a action identifier, followed by parameters stored in a dictionary""" act_identifier, act_options = self.action_map[action] return act_identifier, act_options - def form_request(self, action_identifier:str, action_options:Dict): + def form_request(self, action_identifier: str, action_options: Dict): """Take action in CAOS format and use the execution definition to change it into PrimAITE request format""" act_obj = self.actions[action_identifier] return act_obj.form_request(**action_options) @@ -380,37 +481,57 @@ class ActionManager: node_uuid = self.get_node_uuid_by_idx(node_idx) node = self.sim.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) - return folder_uuids[folder_idx] if len(folder_uuids)>folder_idx else None + return folder_uuids[folder_idx] if 
len(folder_uuids) > folder_idx else None def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]: node_uuid = self.get_node_uuid_by_idx(node_idx) node = self.sim.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) - if len(folder_uuids)<=folder_idx: + if len(folder_uuids) <= folder_idx: return None folder = node.file_system.folders[folder_uuids[folder_idx]] file_uuids = list(folder.files.keys()) - return file_uuids[file_idx] if len(file_uuids)>file_idx else None + return file_uuids[file_idx] if len(file_uuids) > file_idx else None def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]: node_uuid = self.get_node_uuid_by_idx(node_idx) node = self.sim.network.nodes[node_uuid] service_uuids = list(node.services.keys()) - return service_uuids[service_idx] if len(service_uuids)>service_idx else None + return service_uuids[service_idx] if len(service_uuids) > service_idx else None - def get_internet_protocol_by_idx(self, protocol_idx:int) -> str: + def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: + return self.protocols[protocol_idx] + def get_ip_address_by_idx(self, ip_idx: int) -> str: + return self.ip_address_list[ip_idx] - # protocol = self.manager.get_internet_protocol_by_idx(protocol_idx) - # src_ip = self.manager.get_ip_address_by_idx(source_ip_idx) - # src_port = self.manager.get_port_by_idx(source_port_idx) - # dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx) - # dst_port = self.manager.get_port_by_idx(dest_port_idx) + def get_port_by_idx(self, port_idx: int) -> str: + return self.ports[port_idx] + def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str: + node_uuid = self.get_node_uuid_by_idx(node_idx) + node_obj = self.sim.network.nodes[node_uuid] + nics = list(node_obj.nics.keys()) + if len(nics) <= nic_idx: + return None + return nics[nic_idx] + @classmethod + def from_config(cls, session: PrimaiteSession, cfg: Dict) -> "ActionManager": + obj = cls( + 
session=session, + actions=cfg["action_list"], + node_uuids=cfg["options"]["nodes"], + max_folders_per_node=cfg["options"]["max_folders_per_node"], + max_files_per_folder=cfg["options"]["max_files_per_folder"], + max_services_per_node=cfg["options"]["max_services_per_node"], + max_nics_per_node=cfg["options"]["max_nics_per_node"], + max_acl_rules=cfg["options"]["max_acl_rules"], + max_X=cfg["options"]["max_X"], + protocols=session.options.ports, + ports=session.options.protocols, + ip_address_list=None, + act_map=cfg["action_map"], + ) -class UC2RedActions(AbstractAction): - ... - -class UC2GreenActionSpace(ActionManager): - ... + return obj diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 0e682b60..528c0b1a 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -2,14 +2,16 @@ # That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word # If you disagree, make a comment in the PR review and we can discuss from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union, TypeAlias +from typing import Any, Dict, List, Optional, TypeAlias, Union + import numpy as np from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction -ObsType:TypeAlias = Union[Dict, np.ndarray] +ObsType: TypeAlias = Union[Dict, np.ndarray] + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -28,31 +30,28 @@ class AbstractAgent(ABC): # by for example specifying target ip addresses, or converting a node ID into a uuid self.execution_definition = None - def get_obs_from_state(self, state:Dict) -> ObsType: + def convert_state_to_obs(self, state: Dict) -> ObsType: """ state : dict state directly from simulation.describe_state output : dict state according to CAOS. 
""" return self.observation_space.observe(state) - def get_reward_from_state(self, state:Dict) -> float: + def calculate_reward_from_state(self, state: Dict) -> float: return self.reward_function.calculate(state) @abstractmethod - def get_action(self, obs:ObsType, reward:float=None): - # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 1-40, + def get_action(self, obs: ObsType, reward: float = None): + # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action - return ('NODE', 'SERVICE', 'SCAN', '', '') + return ("NODE", "SERVICE", "SCAN", "", "") @abstractmethod def format_request(self, action) -> List[str]: # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - return ['network', 'nodes', '', 'file_system', 'folder', 'root', 'scan'] - - - + return ["network", "nodes", "", "file_system", "folder", "root", "scan"] class AbstractScriptedAgent(AbstractAgent): @@ -60,10 +59,11 @@ class AbstractScriptedAgent(AbstractAgent): ... + class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs:ObsType, reward:float=None): + def get_action(self, obs: ObsType, reward: float = None): return self.action_space.space.sample() diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index fcd8b4b3..0f88b322 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -5,23 +5,58 @@ # 4. Create connection with ARCD GATE # 5. 
idk -from primaite.simulator.sim_container import Simulation -from primaite.game.agent.interface import AbstractAgent +from ipaddress import IPv4Address +from typing import Dict, List + +from pydantic import BaseModel + +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.interface import AbstractAgent +from primaite.game.agent.observations import ( + AclObservation, + FileObservation, + FolderObservation, + ICSObservation, + LinkObservation, + NicObservation, + NodeObservation, + NullObservation, + ServiceObservation, + UC2BlueObservation, + UC2RedObservation, +) +from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.hardware.nodes.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database_service import DatabaseService +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer +from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot +from primaite.simulator.system.services.service import Service + + +class PrimaiteSessionOptions(BaseModel): + ports: List[str] + protocols: List[str] -from typing import List class PrimaiteSession: def __init__(self): self.simulation: Simulation = Simulation() - self.agents:List[AbstractAgent] = [] - self.step_counter:int = 0 - self.episode_counter:int = 0 - + self.agents: List[AbstractAgent] = [] + self.step_counter: 
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
    """Create a session and populate its simulated network from a parsed config dictionary.

    The method works in three passes:

    1. Build every node listed under ``cfg["simulation"]["network"]["nodes"]`` (installing services, NICs, router
       ports and ACL rules), add it to the network and power it on.
    2. Connect the nodes according to ``cfg["simulation"]["network"]["links"]``.
    3. Walk ``cfg["game_config"]["agents"]`` and build each agent's observation space and action space. Creating the
       reward function and the agent objects themselves is not implemented yet.

    :param cfg: Parsed configuration with top-level ``simulation`` and ``game_config`` sections.
    :return: The new session, with its network populated and ``options`` set from the game config.
    :raises KeyError: If a required config key, or a ``ref`` used by a link or an observation, is missing.
    :raises ValueError: If a node has a ``type`` other than computer, server, switch or router.
    """

    def _member_or_none(enum_cls, member_name):
        """Look ``member_name`` up on ``enum_cls``; a missing or empty name means 'not specified' -> ``None``."""
        return enum_cls[member_name] if member_name else None

    def _port_of(node, port_num):
        """Return the physical port ``port_num`` of ``node``; switches keep theirs in a separate container."""
        ports = node.switch_ports if isinstance(node, Switch) else node.ethernet_port
        return ports[port_num]

    sess = cls()
    sim = sess.simulation
    net = sim.network

    # Lookup tables from the ``ref`` strings used in the config to what was built for them.
    ref_map_nodes: Dict[str, str] = {}  # node ref -> node uuid
    ref_map_services: Dict[str, Service] = {}  # service ref -> installed service object
    ref_map_links: Dict[str, str] = {}  # link ref -> link uuid

    # Each key is equal to the 'name' attr of the service class itself, because the installed instance is fetched
    # back out of the software manager under that same key.
    service_types_mapping = {
        "DNSClient": DNSClient,
        "DNSServer": DNSServer,
        "DatabaseClient": DatabaseClient,
        "DatabaseService": DatabaseService,
        # 'database_backup': ,
        "DataManipulationBot": DataManipulationBot,
        # 'web_browser'
    }

    nodes_cfg = cfg["simulation"]["network"]["nodes"]
    links_cfg = cfg["simulation"]["network"]["links"]

    # 1. create nodes
    for node_cfg in nodes_cfg:
        node_ref = node_cfg["ref"]
        n_type = node_cfg["type"]
        if n_type == "computer":
            new_node = Computer(
                hostname=node_cfg["hostname"],
                ip_address=node_cfg["ip_address"],
                subnet_mask=node_cfg["subnet_mask"],
                default_gateway=node_cfg["default_gateway"],
                dns_server=node_cfg["dns_server"],
            )
        elif n_type == "server":
            new_node = Server(
                hostname=node_cfg["hostname"],
                ip_address=node_cfg["ip_address"],
                subnet_mask=node_cfg["subnet_mask"],
                default_gateway=node_cfg["default_gateway"],
                dns_server=node_cfg.get("dns_server"),
            )
        elif n_type == "switch":
            new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
        elif n_type == "router":
            new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
            for port_num, port_cfg in node_cfg.get("ports", {}).items():
                new_node.configure_port(
                    port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
                )
            for r_num, r_cfg in node_cfg.get("acl", {}).items():
                # TODO Refactor Port/IPProtocol lookups.
                # NOTE(review): source and destination address are both read from the single "ip_address" key;
                # this looks like it should be two separate keys - confirm against the config schema.
                new_node.acl.add_rule(
                    action=ACLAction[r_cfg["action"]],
                    src_port=_member_or_none(Port, r_cfg.get("src_port")),
                    dst_port=_member_or_none(Port, r_cfg.get("dst_port")),
                    protocol=_member_or_none(IPProtocol, r_cfg.get("protocol")),
                    src_ip_address=r_cfg.get("ip_address"),
                    dst_ip_address=r_cfg.get("ip_address"),
                    position=r_num,
                )
        else:
            # Fail loudly: carrying on would either hit an undefined ``new_node`` or silently configure and
            # re-register the node built on the previous iteration.
            raise ValueError(f"invalid node type {n_type!r} for node {node_ref!r}")

        for service_cfg in node_cfg.get("services", []):
            service_ref = service_cfg["ref"]
            service_type = service_cfg["type"]
            if service_type not in service_types_mapping:
                print(f"service type not found {service_type}")
                continue
            new_node.software_manager.install(service_types_mapping[service_type])
            new_service = new_node.software_manager.software[service_type]
            ref_map_services[service_ref] = new_service

            # service-dependent options
            if "options" in service_cfg:
                opt = service_cfg["options"]
                if service_type == "DatabaseClient" and "db_server_ip" in opt:
                    new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
                elif service_type == "DNSServer" and "domain_mapping" in opt:
                    for domain, ip in opt["domain_mapping"].items():
                        new_service.dns_register(domain, ip)

        for nic_cfg in node_cfg.get("nics", {}).values():
            new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

        net.add_node(new_node)
        new_node.power_on()
        ref_map_nodes[node_ref] = new_node.uuid

    # 2. create links between nodes
    for link_cfg in links_cfg:
        node_a = net.nodes[ref_map_nodes[link_cfg["endpoint_a_ref"]]]
        node_b = net.nodes[ref_map_nodes[link_cfg["endpoint_b_ref"]]]
        endpoint_a = _port_of(node_a, link_cfg["endpoint_a_port"])
        endpoint_b = _port_of(node_b, link_cfg["endpoint_b_port"])
        new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
        ref_map_links[link_cfg["ref"]] = new_link.uuid

    # 3. create agents
    game_cfg = cfg["game_config"]
    ports_cfg = game_cfg["ports"]
    protocols_cfg = game_cfg["protocols"]
    agents_cfg = game_cfg["agents"]

    # ``__init__`` only annotates ``options`` without assigning it, so give it a value here.
    sess.options = PrimaiteSessionOptions(ports=ports_cfg, protocols=protocols_cfg)

    for agent_cfg in agents_cfg:
        agent_ref = agent_cfg["ref"]  # noqa: F841 - not used until agent creation is implemented
        agent_type = agent_cfg["type"]
        action_space_cfg = agent_cfg["action_space"]
        observation_space_cfg = agent_cfg["observation_space"]
        reward_function_cfg = agent_cfg["reward_function"]  # noqa: F841 - reward creation not implemented yet

        # CREATE OBSERVATION SPACE
        if observation_space_cfg is None:
            obs_space = NullObservation()
        elif observation_space_cfg["type"] == "UC2BlueObservation":
            obs_options = observation_space_cfg["options"]

            # Maps every NIC ip address to the index of the node that owns it. A node may have several NICs, so
            # several addresses can map to the same index.
            # NOTE(review): the +2 offset presumably reserves indices 0 and 1 for special values - confirm.
            node_ip_to_index = {}
            for node_idx, indexed_node_cfg in enumerate(nodes_cfg):
                n_obj = net.nodes[ref_map_nodes[indexed_node_cfg["ref"]]]
                for nic_obj in n_obj.nics.values():
                    node_ip_to_index[nic_obj.ip_address] = node_idx + 2

            node_obs_list = []
            for node_obs_cfg in obs_options["nodes"]:
                node_uuid = ref_map_nodes[node_obs_cfg["node_ref"]]
                node_where = ["network", "nodes", node_uuid]

                # ``where`` paths are built from uuids, so use the service's uuid rather than the object itself.
                service_obs_list = [
                    ServiceObservation(
                        where=node_where + ["services", ref_map_services[service_obs_cfg["service_ref"]].uuid]
                    )
                    for service_obs_cfg in node_obs_cfg.get("services", [])
                ]

                folder_obs_list = []
                for folder_obs_cfg in node_obs_cfg.get("folders", []):
                    folder_where = node_where + ["folders", folder_obs_cfg["folder_name"]]
                    file_obs_list = [
                        FileObservation(where=folder_where + ["files", file_obs_cfg["file_name"]])
                        for file_obs_cfg in folder_obs_cfg.get("files", [])
                    ]
                    folder_obs_list.append(FolderObservation(where=folder_where, files=file_obs_list))

                nic_obs_list = [
                    NicObservation(where=node_where + ["NICs", nic_uuid]) for nic_uuid in net.nodes[node_uuid].nics
                ]

                node_obs_list.append(
                    NodeObservation(
                        where=node_where,
                        services=service_obs_list,
                        folders=folder_obs_list,
                        nics=nic_obs_list,
                        logon_status=False,
                    )
                )

            link_obs_list = [
                LinkObservation(where=["network", "links", ref_map_links[link_obs_cfg["link_ref"]]])
                for link_obs_cfg in obs_options["links"]
            ]

            acl_obs = AclObservation(
                node_ip_to_id=node_ip_to_index,
                ports=ports_cfg,
                # Previously the port list was passed here as well.
                protocols=protocols_cfg,
                # Resolve the router ref to its uuid, like every other ``where`` path in this method.
                where=["network", "nodes", ref_map_nodes[obs_options["acl"]["router_node_ref"]]],
            )
            obs_space = UC2BlueObservation(
                nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation()
            )
        elif observation_space_cfg["type"] == "UC2RedObservation":
            obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim)
        else:
            print("observation space config not specified correctly.")
            obs_space = NullObservation()

        # CREATE ACTION SPACE
        action_space = ActionManager.from_config(sess, action_space_cfg)  # noqa: F841 - consumed once agents exist

        # CREATE REWARD FUNCTION

        # CREATE AGENT
        if agent_type == "GreenWebBrowsingAgent":
            ...
        elif agent_type == "GATERLAgent":
            ...
        elif agent_type == "RedDatabaseCorruptingAgent":
            ...
        else:
            print("agent type not found")

    return sess