From e0f8c3c5eaf5bc8dea27d58af7160dfcf64d0b5f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 19 Oct 2023 01:56:40 +0100 Subject: [PATCH] Add documentation --- example_config.yaml | 10 +- src/primaite/game/__init__.py | 1 + src/primaite/game/agent/actions.py | 323 +++++++++++++++++++++++---- src/primaite/game/agent/interface.py | 4 +- src/primaite/game/agent/rewards.py | 74 +++--- src/primaite/game/session.py | 166 +++++++++----- 6 files changed, 445 insertions(+), 133 deletions(-) diff --git a/example_config.yaml b/example_config.yaml index b700da5c..e16411fa 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -2,10 +2,10 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_episodes: 2 - n_learn_steps: 128 - n_eval_episodes: 2 - n_eval_steps: 128 + n_learn_episodes: 1 + n_learn_steps: 8 + n_eval_episodes: 0 + n_eval_steps: 8 game_config: @@ -451,6 +451,8 @@ game_config: node_ref: database_server folder_name: database file_name: database.db + + - type: WEB_SERVER_404_PENALTY weight: 0.5 options: diff --git a/src/primaite/game/__init__.py b/src/primaite/game/__init__.py index e69de29b..5d7a721f 100644 --- a/src/primaite/game/__init__.py +++ b/src/primaite/game/__init__.py @@ -0,0 +1 @@ +"""PrimAITE Game Layer.""" diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 6a1d5bcd..4c4aaab4 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1,6 +1,16 @@ +""" +This module contains the ActionManager class which belongs to the Agent class. + +An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of +AbstractAction. The ActionManager is responsible for: + 1. Creating the action space from a list of action types. + 2. Converting an integer action choice into a specific action and parameter choice. + 3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This + ensures that requests conform to the simulator's request format. +""" import itertools from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces @@ -13,22 +23,9 @@ if TYPE_CHECKING: from primaite.game.session import PrimaiteSession -class ExecutionDefiniton(ABC): - """ - Converter from actions to simulator requests. - - Allows adding extra data/context that defines in more detail what an action means. - """ - - """ - Examples: - ('node', 'service', 'scan', 2, 0) means scan the first service on node index 2 - -> ['network', 'nodes', , 'services', , 'scan'w] - """ - ... - - class AbstractAction(ABC): + """Base class for actions.""" + @abstractmethod def __init__(self, manager: "ActionManager", **kwargs) -> None: """ @@ -46,6 +43,8 @@ class AbstractAction(ABC): """Dictionary describing the number of options for each parameter of this action. The keys of this dict must align with the keyword args of the form_request method.""" self.manager: ActionManager = manager + """Reference to the ActionManager which created this action. This is used to access the session and simulation + objects.""" @abstractmethod def form_request(self) -> List[str]: @@ -54,6 +53,8 @@ class AbstractAction(ABC): class DoNothingAction(AbstractAction): + """Action which does nothing. This is here to allow agents to be idle if they choose to.""" + def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) self.name = "DONOTHING" @@ -65,6 +66,7 @@ class DoNothingAction(AbstractAction): # with one option. This just aids the Action Manager to enumerate all possibilities. def form_request(self, **kwargs) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["do_nothing"] @@ -77,12 +79,13 @@ class NodeServiceAbstractAction(AbstractAction): """ @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} self.verb: str def form_request(self, node_id: int, service_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_uuid = self.manager.get_node_uuid_by_idx(node_id) service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id) if node_uuid is None or service_uuid is None: @@ -91,54 +94,77 @@ class NodeServiceAbstractAction(AbstractAction): class NodeServiceScanAction(NodeServiceAbstractAction): + """Action which scans a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "scan" class NodeServiceStopAction(NodeServiceAbstractAction): + """Action which stops a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "stop" class NodeServiceStartAction(NodeServiceAbstractAction): + """Action which starts a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "start" class NodeServicePauseAction(NodeServiceAbstractAction): + """Action which pauses a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "pause" class NodeServiceResumeAction(NodeServiceAbstractAction): + """Action which resumes a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "resume" class NodeServiceRestartAction(NodeServiceAbstractAction): + """Action which restarts a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "restart" class NodeServiceDisableAction(NodeServiceAbstractAction): + """Action which disables a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "disable" class NodeServiceEnableAction(NodeServiceAbstractAction): + """Action which enables a service.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "enable" class NodeFolderAbstractAction(AbstractAction): + """ + Base class for folder actions. + + Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from + this base class. + """ + @abstractmethod def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager=manager) @@ -146,6 +172,7 @@ class NodeFolderAbstractAction(AbstractAction): self.verb: str def form_request(self, node_id: int, folder_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_uuid = self.manager.get_node_uuid_by_idx(node_id) folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) if node_uuid is None or folder_uuid is None: @@ -154,30 +181,44 @@ class NodeFolderAbstractAction(AbstractAction): class NodeFolderScanAction(NodeFolderAbstractAction): + """Action which scans a folder.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "scan" class NodeFolderCheckhashAction(NodeFolderAbstractAction): + """Action which checks the hash of a folder.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "checkhash" class NodeFolderRepairAction(NodeFolderAbstractAction): + """Action which repairs a folder.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "repair" class NodeFolderRestoreAction(NodeFolderAbstractAction): + """Action which restores a folder.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "restore" class NodeFileAbstractAction(AbstractAction): + """Abstract base class for file actions. + + Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit + from this base class. + """ + @abstractmethod def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager=manager) @@ -185,6 +226,7 @@ class NodeFileAbstractAction(AbstractAction): self.verb: str def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_uuid = self.manager.get_node_uuid_by_idx(node_id) folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) @@ -194,42 +236,60 @@ class NodeFileAbstractAction(AbstractAction): class NodeFileScanAction(NodeFileAbstractAction): + """Action which scans a file.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "scan" class NodeFileCheckhashAction(NodeFileAbstractAction): + """Action which checks the hash of a file.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "checkhash" class NodeFileDeleteAction(NodeFileAbstractAction): + """Action which deletes a file.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "delete" class NodeFileRepairAction(NodeFileAbstractAction): + """Action which repairs a file.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "repair" class NodeFileRestoreAction(NodeFileAbstractAction): + """Action which restores a file.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "restore" class NodeFileCorruptAction(NodeFileAbstractAction): + """Action which corrupts a file.""" + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "corrupt" class NodeAbstractAction(AbstractAction): + """ + Abstract base class for node actions. + + Any action which applies to a node and uses node_id as its only parameter can inherit from this base class. + """ + @abstractmethod def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) @@ -237,35 +297,46 @@ class NodeAbstractAction(AbstractAction): self.verb: str def form_request(self, node_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_uuid = self.manager.get_node_uuid_by_idx(node_id) return ["network", "node", node_uuid, self.verb] class NodeOSScanAction(NodeAbstractAction): + """Action which scans a node's OS.""" + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "scan" class NodeShutdownAction(NodeAbstractAction): + """Action which shuts down a node.""" + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "shutdown" class NodeStartupAction(NodeAbstractAction): + """Action which starts up a node.""" + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "startup" class NodeResetAction(NodeAbstractAction): + """Action which resets a node.""" + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "reset" class NetworkACLAddRuleAction(AbstractAction): + """Action which adds a rule to a router's ACL.""" + def __init__( self, manager: "ActionManager", @@ -276,6 +347,21 @@ class NetworkACLAddRuleAction(AbstractAction): num_protocols: int, **kwargs, ) -> None: + """Init method for NetworkACLAddRuleAction. + + :param manager: Reference to the ActionManager which created this action. + :type manager: ActionManager + :param target_router_uuid: UUID of the router to which the ACL rule should be added. + :type target_router_uuid: str + :param max_acl_rules: Maximum number of ACL rules that can be added to the router. + :type max_acl_rules: int + :param num_ips: Number of IP addresses in the simulation. + :type num_ips: int + :param num_ports: Number of ports in the simulation. + :type num_ports: int + :param num_protocols: Number of protocols in the simulation. + :type num_protocols: int + """ super().__init__(manager=manager) num_permissions = 3 self.shape: Dict[str, int] = { @@ -290,8 +376,16 @@ class NetworkACLAddRuleAction(AbstractAction): self.target_router_uuid: str = target_router_uuid def form_request( - self, position, permission, source_ip_id, dest_ip_id, source_port_id, dest_port_id, protocol_id + self, + position: int, + permission: int, + source_ip_id: int, + dest_ip_id: int, + source_port_id: int, + dest_port_id: int, + protocol_id: int, ) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if permission == 0: permission_str = "UNUSED" return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS @@ -354,22 +448,51 @@ class NetworkACLAddRuleAction(AbstractAction): class NetworkACLRemoveRuleAction(AbstractAction): + """Action which removes a rule from a router's ACL.""" + def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None: + """Init method for NetworkACLRemoveRuleAction. + + :param manager: Reference to the ActionManager which created this action. + :type manager: ActionManager + :param target_router_uuid: UUID of the router from which the ACL rule should be removed. + :type target_router_uuid: str + :param max_acl_rules: Maximum number of ACL rules that can be added to the router. + :type max_acl_rules: int + """ super().__init__(manager=manager) self.shape: Dict[str, int] = {"position": max_acl_rules} self.target_router_uuid: str = target_router_uuid def form_request(self, position: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position] class NetworkNICAbstractAction(AbstractAction): + """ + Abstract base class for NIC actions. + + Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base + class. + """ + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + """Init method for NetworkNICAbstractAction. + + :param manager: Reference to the ActionManager which created this action. + :type manager: ActionManager + :param num_nodes: Number of nodes in the simulation. + :type num_nodes: int + :param max_nics_per_node: Maximum number of NICs per node. + :type max_nics_per_node: int + """ super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} self.verb: str def form_request(self, node_id: int, nic_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id) nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id) if node_uuid is None or nic_uuid is None: @@ -385,45 +508,24 @@ class NetworkNICAbstractAction(AbstractAction): class NetworkNICEnableAction(NetworkNICAbstractAction): + """Action which enables a NIC.""" + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) self.verb = "enable" class NetworkNICDisableAction(NetworkNICAbstractAction): + """Action which disables a NIC.""" + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) self.verb = "disable" -# class NetworkNICDisableAction(AbstractAction): -# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: -# super().__init__(manager=manager) -# self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} - -# def form_request(self, node_id: int, nic_id: int) -> List[str]: -# return [ -# "network", -# "node", -# self.manager.get_node_uuid_by_idx(node_idx=node_id), -# "nic", -# self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id), -# "disable", -# ] - - class ActionManager: - # let the action manager handle the conversion of action spaces into a single discrete integer space. - # + """Class which manages the action space for an agent.""" - # when action space is created, it will take subspaces and generate an action map by enumerating all possibilities, - # BUT, the action map can be provided in the config, in which case it will use that. - - # action map is basically just a mapping between integer and CAOS action (incl. parameter values) - # for example the action map can be: - # 0: DONOTHING - # 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0 - # 2: ...... __act_class_identifiers: Dict[str, type] = { "DONOTHING": DoNothingAction, "NODE_SERVICE_SCAN": NodeServiceScanAction, @@ -453,6 +555,7 @@ class ActionManager: "NETWORK_NIC_ENABLE": NetworkNICEnableAction, "NETWORK_NIC_DISABLE": NetworkNICDisableAction, } + """Dictionary which maps action type strings to the corresponding action class.""" def __init__( self, @@ -469,6 +572,33 @@ class ActionManager: ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions ) -> None: + """Init method for ActionManager. + + :param session: Reference to the session to which the agent belongs. + :type session: PrimaiteSession + :param actions: List of action types which should be made available to the agent. + :type actions: List[str] + :param node_uuids: List of node UUIDs that this agent can act on. + :type node_uuids: List[str] + :param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape. + :type max_folders_per_node: int + :param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape. + :type max_files_per_folder: int + :param max_services_per_node: Maximum number of services per node. Used for calculating action shape. + :type max_services_per_node: int + :param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape. + :type max_nics_per_node: int + :param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape. + :type max_acl_rules: int + :param protocols: List of protocols that are available in the simulation. Used for calculating action shape. + :type protocols: List[str] + :param ports: List of ports that are available in the simulation. Used for calculating action shape. + :type ports: List[str] + :param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape. + :type ip_address_list: Optional[List[str]] + :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. + :type act_map: Optional[Dict[int, Dict]] + """ self.session: "PrimaiteSession" = session self.sim: Simulation = self.session.simulation self.node_uuids: List[str] = node_uuids @@ -578,18 +708,48 @@ class ActionManager: @property def space(self) -> spaces.Space: + """Return the gymnasium action space for this agent.""" return spaces.Discrete(len(self.action_map)) - def get_node_uuid_by_idx(self, node_idx): + def get_node_uuid_by_idx(self, node_idx: int) -> str: + """Get the node UUID corresponding to the given index. + + :param node_idx: The index of the node to retrieve. + :type node_idx: int + :return: The node UUID. + :rtype: str + """ return self.node_uuids[node_idx] - def get_folder_uuid_by_idx(self, node_idx, folder_idx) -> Optional[str]: + def get_folder_uuid_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: + """Get the folder UUID corresponding to the given node and folder indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param folder_idx: The index of the folder on the node. + :type folder_idx: int + :return: The UUID of the folder. Or None if the node has fewer folders than the given index. + :rtype: Optional[str] + """ + node_uuid = self.get_node_uuid_by_idx(node_idx) node = self.sim.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None - def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]: + def get_file_uuid_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: + """Get the file UUID corresponding to the given node, folder, and file indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param folder_idx: The index of the folder on the node. + :type folder_idx: int + :param file_idx: The index of the file in the folder. + :type file_idx: int + :return: The UUID of the file. Or None if the node has fewer folders than the given index, or the folder has + fewer files than the given index. + :rtype: Optional[str] + """ node_uuid = self.get_node_uuid_by_idx(node_idx) node = self.sim.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) @@ -599,22 +759,64 @@ class ActionManager: file_uuids = list(folder.files.keys()) return file_uuids[file_idx] if len(file_uuids) > file_idx else None - def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]: + def get_service_uuid_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: + """Get the service UUID corresponding to the given node and service indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param service_idx: The index of the service on the node. + :type service_idx: int + :return: The UUID of the service. Or None if the node has fewer services than the given index. + :rtype: Optional[str] + """ node_uuid = self.get_node_uuid_by_idx(node_idx) node = self.sim.network.nodes[node_uuid] service_uuids = list(node.services.keys()) return service_uuids[service_idx] if len(service_uuids) > service_idx else None def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: + """Get the internet protocol corresponding to the given index. + + :param protocol_idx: The index of the protocol to retrieve. + :type protocol_idx: int + :return: The protocol. + :rtype: str + """ return self.protocols[protocol_idx] def get_ip_address_by_idx(self, ip_idx: int) -> str: + """ + Get the IP address corresponding to the given index. + + :param ip_idx: The index of the IP address to retrieve. + :type ip_idx: int + :return: The IP address. + :rtype: str + """ return self.ip_address_list[ip_idx] def get_port_by_idx(self, port_idx: int) -> str: + """ + Get the port corresponding to the given index. + + :param port_idx: The index of the port to retrieve. + :type port_idx: int + :return: The port. + :rtype: str + """ return self.ports[port_idx] def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str: + """ + Get the NIC UUID corresponding to the given node and NIC indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param nic_idx: The index of the NIC on the node. + :type nic_idx: int + :return: The NIC UUID. + :rtype: str + """ node_uuid = self.get_node_uuid_by_idx(node_idx) node_obj = self.sim.network.nodes[node_uuid] nics = list(node_obj.nics.keys()) @@ -624,6 +826,31 @@ class ActionManager: @classmethod def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager": + """ + Construct an ActionManager from a config definition. + + The action space config supports the following three sections: + 1. ``action_list`` + ``action_list`` contians a list action components which need to be included in the action space. + Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options + which will be passed to the action class's __init__ method during initialisation. + 2. ``action_map`` + Since the agent uses a discrete action space which acts as a flattened version of the component-based + action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful + action and values of parameters. For example action 0 can correspond to do nothing, action 1 can + correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be " + 3. ``options`` + ``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method. + These options are used to calculate the shape of the action space, and to provide additional information + to the ActionManager which is required to convert the agent's action choice into a CAOS request. + + :param session: The Primaite Session to which the agent belongs. + :type session: PrimaiteSession + :param cfg: The action space config. + :type cfg: Dict + :return: The constructed ActionManager. + :rtype: ActionManager + """ obj = cls( session=session, actions=cfg["action_list"], diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 817e59b1..5f121fcc 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,6 +1,4 @@ -# TODO: remove this comment... This is just here to point out that I've named this 'actor' rather than 'agent' -# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word -# If you disagree, make a comment in the PR review and we can discuss +"""Interface for agents.""" from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TypeAlias, Union diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 85da95da..67e6ee50 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -1,8 +1,9 @@ -from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE - from abc import ABC, abstractmethod from typing import Any, Dict, List, Tuple, TYPE_CHECKING + from primaite import getLogger +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + _LOGGER = getLogger(__name__) if TYPE_CHECKING: @@ -10,14 +11,13 @@ if TYPE_CHECKING: class AbstractReward: - @abstractmethod def calculate(self, state: Dict) -> float: return 0.0 @classmethod @abstractmethod - def from_config(cls, config:dict, session:"PrimaiteSession") -> "AbstractReward": + def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward": return cls() @@ -26,16 +26,26 @@ class DummyReward(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: dict, session:"PrimaiteSession") -> "DummyReward": + def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward": return cls() + class DatabaseFileIntegrity(AbstractReward): - def __init__(self, node_uuid:str, folder_name:str, file_name:str) -> None: - self.location_in_state = ["network", "nodes", node_uuid, "file_system", "folders",folder_name, "files", file_name] + def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None: + self.location_in_state = [ + "network", + "nodes", + node_uuid, + "file_system", + "folders", + folder_name, + "files", + file_name, + ] def calculate(self, state: Dict) -> float: database_file_state = access_from_nested_dict(state, self.location_in_state) - health_status = database_file_state['health_status'] + health_status = database_file_state["health_status"] if health_status == "corrupted": return -1 elif health_status == "good": @@ -48,29 +58,39 @@ class DatabaseFileIntegrity(AbstractReward): node_ref = config.get("node_ref") folder_name = config.get("folder_name") file_name = config.get("file_name") - if not (node_ref): - _LOGGER.error(f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified") - return DummyReward() #TODO: better error handling + if not node_ref: + _LOGGER.error( + f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified" + ) + return DummyReward() # TODO: better error handling if not folder_name: - _LOGGER.error(f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified") - return DummyReward() # TODO: better error handling + _LOGGER.error( + f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified" + ) + return DummyReward() # TODO: better error handling if not file_name: - _LOGGER.error(f"{cls.__name__} could not be initialised from config because file_name parameter was not specified") - return DummyReward() # TODO: better error handling + _LOGGER.error( + f"{cls.__name__} could not be initialised from config because file_name parameter was not specified" + ) + return DummyReward() # TODO: better error handling node_uuid = session.ref_map_nodes[node_ref] if not node_uuid: - _LOGGER.error(f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation") - return DummyReward() # TODO: better error handling + _LOGGER.error( + f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation" + ) + return DummyReward() # TODO: better error handling + + return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name) - return cls(node_uuid = node_uuid, folder_name=folder_name, file_name=file_name) class WebServer404Penalty(AbstractReward): - def __init__(self, node_uuid:str, service_uuid:str) -> None: - self.location_in_state = ['network','nodes', node_uuid, 'services', service_uuid] + def __init__(self, node_uuid: str, service_uuid: str) -> None: + self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid] def calculate(self, state: Dict) -> float: web_service_state = access_from_nested_dict(state, self.location_in_state) - most_recent_return_code = web_service_state['most_recent_return_code'] + most_recent_return_code = web_service_state["most_recent_return_code"] + # TODO: reward needs to use the current web state. Observation should return web state at the time of last scan. if most_recent_return_code == 200: return 1 elif most_recent_return_code == 404: @@ -85,13 +105,13 @@ class WebServer404Penalty(AbstractReward): if not (node_ref and service_ref): msg = f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not found in reward config." _LOGGER.warn(msg) - return DummyReward() #TODO: should we error out with incorrect inputs? Probably! + return DummyReward() # TODO: should we error out with incorrect inputs? Probably! node_uuid = session.ref_map_nodes[node_ref] service_uuid = session.ref_map_services[service_ref].uuid if not (node_uuid and service_uuid): msg = f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not found in the simulator." _LOGGER.warn(msg) - return DummyReward() # TODO: consider erroring here as well + return DummyReward() # TODO: consider erroring here as well return cls(node_uuid=node_uuid, service_uuid=service_uuid) @@ -101,13 +121,13 @@ class RewardFunction: "DUMMY": DummyReward, "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, - } + } def __init__(self): self.reward_components: List[Tuple[AbstractReward, float]] = [] "attribute reward_components keeps track of reward components and the weights assigned to each." - def regsiter_component(self, component:AbstractReward, weight:float=1.0) -> None: + def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None: self.reward_components.append((component, weight)) def calculate(self, state: Dict) -> float: @@ -124,8 +144,8 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] - weight = rew_component_cfg.get("weight",1.0) + weight = rew_component_cfg.get("weight", 1.0) rew_class = cls.__rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get('options',{}), session=session) + rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session) new.regsiter_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 406308b9..f29d03dd 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,39 +1,17 @@ -# What do? Be an entry point for using PrimAITE -# 1. parse monoconfig -# 2. craete simulation -# 3. create actors and configure their actions/observations/rewards/ anything else -# 4. Create connection with ARCD GATE -# 5. idk - +"""PrimAITE session - the main entry point to training agents on PrimAITE.""" from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Tuple + +from arcd_gate.client.gate_client import ActType, GATEClient from gymnasium import spaces -from gymnasium.spaces.utils import flatten, flatten_space, unflatten -from gymnasium.core import ObsType, ActType - -import numpy as np - +from gymnasium.core import ActType, ObsType +from gymnasium.spaces.utils import flatten, flatten_space from pydantic import BaseModel from primaite import getLogger -from primaite.game.agent.GATE_agents import GATERLAgent from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, RandomAgent -from primaite.game.agent.observations import ( - AclObservation, - FileObservation, - FolderObservation, - ICSObservation, - LinkObservation, - NicObservation, - NodeObservation, - NullObservation, - ObservationSpace, - ServiceObservation, - UC2BlueObservation, - UC2GreenObservation, - UC2RedObservation, -) +from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer @@ -50,53 +28,75 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.service import Service -from arcd_gate.client.gate_client import GATEClient, ActType -from numpy import ndarray - _LOGGER = getLogger(__name__) + class PrimaiteGATEClient(GATEClient): - def __init__(self, parent_session:"PrimaiteSession", service_port: int = 50000): + def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000): + """Create a new GATE client for PrimAITE. + + :param parent_session: The parent session object. + :type parent_session: PrimaiteSession + :param service_port: The port on which the GATE service is running. + :type service_port: int, optional""" super().__init__(service_port=service_port) - self.parent_session:"PrimaiteSession" = parent_session + self.parent_session: "PrimaiteSession" = parent_session @property def rl_framework(self) -> str: + """The reinforcement learning framework to use.""" return self.parent_session.training_options.rl_framework @property def rl_algorithm(self) -> str: + """The reinforcement learning algorithm to use.""" return self.parent_session.training_options.rl_algorithm @property def seed(self) -> int | None: + """The seed to use for the environment's random number generator.""" return self.parent_session.training_options.seed @property def n_learn_episodes(self) -> int: + """The number of episodes in each learning run.""" return self.parent_session.training_options.n_learn_episodes @property def n_learn_steps(self) -> int: + """The number of steps in each learning episode.""" return self.parent_session.training_options.n_learn_steps @property def n_eval_episodes(self) -> int: + """The number of episodes in each evaluation run.""" return self.parent_session.training_options.n_eval_episodes @property def n_eval_steps(self) -> int: + """The number of steps in each evaluation episode.""" return self.parent_session.training_options.n_eval_steps @property def action_space(self) -> spaces.Space: + """The gym action space of the agent.""" return self.parent_session.rl_agent.action_space.space @property def observation_space(self) -> spaces.Space: + """The gymnasium observation space of the agent.""" return flatten_space(self.parent_session.rl_agent.observation_space.space) def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]: + """Take a step in the environment. + + This method is called by GATE to advance the simulation by one timestep. + + :param action: The agent's action. + :type action: ActType + :return: The observation, reward, terminal flag, truncated flag, and info dictionary. + :rtype: Tuple[ObsType, float, bool, bool, Dict] + """ self.parent_session.rl_agent.most_recent_action = action self.parent_session.step() state = self.parent_session.simulation.describe_state() @@ -108,8 +108,19 @@ class PrimaiteGATEClient(GATEClient): info = {} return obs, rew, term, trunc, info - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]: + """Reset the environment. + + This method is called when the environment is initialized and at the end of each episode. + + :param seed: The seed to use for the environment's random number generator. + :type seed: int, optional + :param options: Additional options for the reset. None are used by PrimAITE but this is included for + compatibility with GATE. + :type options: dict[str, Any], optional + :return: The initial observation and an empty info dictionary. + :rtype: Tuple[ObsType, Dict] + """ self.parent_session.reset() state = self.parent_session.simulation.describe_state() obs = self.parent_session.rl_agent.observation_space.observe(state) @@ -117,44 +128,78 @@ class PrimaiteGATEClient(GATEClient): return obs, {} def close(self): + """Close the session, this will stop the gate client and close the simulation.""" self.parent_session.close() + class PrimaiteSessionOptions(BaseModel): + """Global options which are applicable to all of the agents in the game. + + Currently this is used to restrict which ports and protocols exist in the world of the simulation.""" + ports: List[str] protocols: List[str] -class TrainingOptions(BaseModel): - rl_framework:str - rl_algorithm:str - seed:Optional[int] - n_learn_episodes:int - n_learn_steps:int - n_eval_episodes:int - n_eval_steps:int +class TrainingOptions(BaseModel): + """Options for training the RL agent.""" + + rl_framework: str + rl_algorithm: str + seed: Optional[int] + n_learn_episodes: int + n_learn_steps: int + n_eval_episodes: int + n_eval_steps: int class PrimaiteSession: + """ + The main entrypoint for PrimAITE sessions, this coordinates a simulation, agents, and connections to ARCD GATE. + """ + def __init__(self): self.simulation: Simulation = Simulation() + """Simulation object with which the agents will interact.""" self.agents: List[AbstractAgent] = [] + """List of agents.""" self.rl_agent: AbstractAgent - # which of the agents should be used for sending RL data to GATE client? + """The agent from the list which communicates with GATE to perform reinforcement learning.""" self.step_counter: int = 0 + """Current timestep within the episode.""" self.episode_counter: int = 0 + """Current episode number.""" self.options: PrimaiteSessionOptions + """Special options that apply for the entire game.""" self.training_options: TrainingOptions + """Options specific to agent training.""" self.ref_map_nodes: Dict[str, Node] = {} + """Mapping from unique node reference name to node object. Used when parsing config files.""" self.ref_map_services: Dict[str, Service] = {} + """Mapping from human-readable service reference to service object. Used for parsing config files.""" self.ref_map_links: Dict[str, Link] = {} + """Mapping from human-readable link reference to link object. Used when parsing config files.""" self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self) + """Reference to a GATE Client object, which will send data to GATE service for training RL agent.""" def start_session(self, opts="TODO..."): - """Commence the session, this gives the gate client control over the simulation/agent loop.""" + """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" self.gate_client.start() def step(self): + """ + Perform one step of the simulation/agent loop. + + This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each + agent. The steps are as follows: + 1. The simulation state is updated. + 2. The simulation state is sent to each agent. + 3. Each agent converts the state to an observation and calculates a reward. + 4. Each agent chooses an action based on the observation. + 5. Each agent converts the action to a request. + 6. The simulation applies the requests. + """ _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") # currently designed with assumption that all agents act once per step in order @@ -192,19 +237,36 @@ class PrimaiteSession: self.step_counter += 1 def reset(self): - pass + """Reset the session, this will reset the simulation.""" + return NotImplemented def close(self): - pass + """Close the session, this will stop the gate client and close the simulation.""" + return NotImplemented @classmethod def from_config(cls, cfg: dict) -> "PrimaiteSession": + """Create a PrimaiteSession object from a config dictionary. + + The config dictionary should have the following top-level keys: + 1. training_config: options for training the RL agent. Used by GATE. + 2. game_config: options for the game itself. Used by PrimaiteSession. + 3. simulation: defines the network topology and the initial state of the simulation. + + The specification for each of the three major areas is described in a separate documentation page. + # TODO: create documentation page and add links to it here. + + :param cfg: The config dictionary. + :type cfg: dict + :return: A PrimaiteSession object. + :rtype: PrimaiteSession + """ sess = cls() sess.options = PrimaiteSessionOptions( ports=cfg["game_config"]["ports"], protocols=cfg["game_config"]["protocols"], ) - sess.training_options = TrainingOptions(**cfg['training_config']) + sess.training_options = TrainingOptions(**cfg["training_config"]) sim = sess.simulation net = sim.network @@ -295,7 +357,11 @@ class PrimaiteSession: net.add_node(new_node) new_node.power_on() - sess.ref_map_nodes[node_ref] = new_node.uuid # TODO: fix incosistency with service and link. Node gets added by uuid, but service gets reference to object + sess.ref_map_nodes[ + node_ref + ] = ( + new_node.uuid + ) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object # 2. create links between nodes for link_cfg in links_cfg: @@ -314,12 +380,10 @@ class PrimaiteSession: # 3. create agents game_cfg = cfg["game_config"] - ports_cfg = game_cfg["ports"] - protocols_cfg = game_cfg["protocols"] agents_cfg = game_cfg["agents"] for agent_cfg in agents_cfg: - agent_ref = agent_cfg["ref"] + agent_ref = agent_cfg["ref"] # noqa: F841 agent_type = agent_cfg["type"] action_space_cfg = agent_cfg["action_space"] observation_space_cfg = agent_cfg["observation_space"]