diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 1df25d27..68e42fb1 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -10,13 +10,12 @@ AbstractAction. The ActionManager is responsible for: ensures that requests conform to the simulator's request format. """ from abc import abstractmethod -from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo from primaite import getLogger -from primaite.game.agent.actions.manager import ActionManager -from primaite.game.agent.actions.manager import AbstractAction +from primaite.game.agent.actions.manager import AbstractAction, ActionManager from primaite.game.agent.actions.service import NodeServiceAbstractAction from primaite.interface.request import RequestFormat @@ -1238,4 +1237,3 @@ class RansomwareLaunchC2ServerAction(AbstractAction): return ["do_nothing"] # This action currently doesn't require any further configuration options. return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"] - diff --git a/src/primaite/game/agent/actions/__init__.py b/src/primaite/game/agent/actions/__init__.py index e69de29b..24a3ad67 100644 --- a/src/primaite/game/agent/actions/__init__.py +++ b/src/primaite/game/agent/actions/__init__.py @@ -0,0 +1,27 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +from primaite.game.agent.actions.manager import ActionManager +from primaite.game.agent.actions.service import ( + NodeServiceDisableAction, + NodeServiceEnableAction, + NodeServiceFixAction, + NodeServicePauseAction, + NodeServiceRestartAction, + NodeServiceResumeAction, + NodeServiceScanAction, + NodeServiceStartAction, + NodeServiceStopAction, +) + +__all__ = ( + "NodeServiceDisableAction", + "NodeServiceEnableAction", + "NodeServiceFixAction", + "NodeServicePauseAction", + "NodeServiceRestartAction", + "NodeServiceResumeAction", + "NodeServiceScanAction", + "NodeServiceStartAction", + "NodeServiceStopAction", + "ActionManager", +) diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py new file mode 100644 index 00000000..22e0a465 --- /dev/null +++ b/src/primaite/game/agent/actions/acl.py @@ -0,0 +1,170 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Dict, List, Literal + +from pydantic import BaseModel, Field, field_validator, ValidationInfo + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.game.game import _LOGGER + + +class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): + """Action which adds a rule to a router's ACL.""" + + class ACLRuleOptions(BaseModel): + """Validator for ACL_ADD_RULE options.""" + + target_router: str + """On which router to add the rule, must be specified.""" + position: int + """At what position to add the rule, must be specified.""" + permission: Literal[1, 2] + """Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY.""" + source_ip_id: int = Field(default=1, ge=1) + """Rule source IP address. By default, all ip addresses.""" + source_wildcard_id: int = Field(default=0, ge=0) + """Rule source IP wildcard. By default, use the wildcard at index 0 from action manager.""" + source_port_id: int = Field(default=1, ge=1) + """Rule source port. By default, all source ports.""" + dest_ip_id: int = Field(default=1, ge=1) + """Rule destination IP address. By default, all ip addresses.""" + dest_wildcard_id: int = Field(default=0, ge=0) + """Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager.""" + dest_port_id: int = Field(default=1, ge=1) + """Rule destination port. By default, all destination ports.""" + protocol_id: int = Field(default=1, ge=1) + """Rule protocol. By default, all protocols.""" + + @field_validator( + "source_ip_id", + "source_port_id", + "source_wildcard_id", + "dest_ip_id", + "dest_port_id", + "dest_wildcard_id", + "protocol_id", + mode="before", + ) + @classmethod + def not_none(cls, v: str, info: ValidationInfo) -> int: + """If None is passed, use the default value instead.""" + if v is None: + return cls.model_fields[info.field_name].default + return v + + def __init__( + self, + manager: "ActionManager", + max_acl_rules: int, + num_ips: int, + num_ports: int, + num_protocols: int, + **kwargs, + ) -> None: + """Init method for RouterACLAddRuleAction. + + :param manager: Reference to the ActionManager which created this action. + :type manager: ActionManager + :param max_acl_rules: Maximum number of ACL rules that can be added to the router. + :type max_acl_rules: int + :param num_ips: Number of IP addresses in the simulation. + :type num_ips: int + :param num_ports: Number of ports in the simulation. + :type num_ports: int + :param num_protocols: Number of protocols in the simulation. + :type num_protocols: int + """ + super().__init__(manager=manager) + num_permissions = 3 + self.shape: Dict[str, int] = { + "position": max_acl_rules, + "permission": num_permissions, + "source_ip_id": num_ips, + "dest_ip_id": num_ips, + "source_port_id": num_ports, + "dest_port_id": num_ports, + "protocol_id": num_protocols, + } + + def form_request( + self, + target_router: str, + position: int, + permission: int, + source_ip_id: int, + source_wildcard_id: int, + dest_ip_id: int, + dest_wildcard_id: int, + source_port_id: int, + dest_port_id: int, + protocol_id: int, + ) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + # Validate incoming data. + parsed_options = RouterACLAddRuleAction.ACLRuleOptions( + target_router=target_router, + position=position, + permission=permission, + source_ip_id=source_ip_id, + source_wildcard_id=source_wildcard_id, + dest_ip_id=dest_ip_id, + dest_wildcard_id=dest_wildcard_id, + source_port_id=source_port_id, + dest_port_id=dest_port_id, + protocol_id=protocol_id, + ) + if parsed_options.permission == 1: + permission_str = "PERMIT" + elif parsed_options.permission == 2: + permission_str = "DENY" + else: + _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") + + if parsed_options.protocol_id == 1: + protocol = "ALL" + else: + protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2) + # subtract 2 to account for UNUSED=0 and ALL=1. + + if parsed_options.source_ip_id == 1: + src_ip = "ALL" + else: + src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2) + # subtract 2 to account for UNUSED=0, and ALL=1 + + src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id) + + if parsed_options.source_port_id == 1: + src_port = "ALL" + else: + src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2) + # subtract 2 to account for UNUSED=0, and ALL=1 + + if parsed_options.dest_ip_id == 1: + dst_ip = "ALL" + else: + dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2) + # subtract 2 to account for UNUSED=0, and ALL=1 + dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id) + + if parsed_options.dest_port_id == 1: + dst_port = "ALL" + else: + dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2) + # subtract 2 to account for UNUSED=0, and ALL=1 + + return [ + "network", + "node", + target_router, + "acl", + "add_rule", + permission_str, + protocol, + str(src_ip), + src_wildcard, + src_port, + str(dst_ip), + dst_wildcard, + dst_port, + position, + ] diff --git a/src/primaite/game/agent/actions/application.py b/src/primaite/game/agent/actions/application.py new file mode 100644 index 00000000..4b82ffd3 --- /dev/null +++ b/src/primaite/game/agent/actions/application.py @@ -0,0 +1,64 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from abc import abstractmethod +from typing import ClassVar, Dict + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + + +class NodeApplicationAbstractAction(AbstractAction): + """ + Base class for application actions. + + Any action which applies to an application and uses node_id and application_id as its only two parameters can + inherit from this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema): + node_name: str + application_name: str + + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.application_name is None: + return ["do_nothing"] + return ["network", "node", config.node_name, "application", config.application_name, cls.verb] + + +class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"): + """Action which executes an application.""" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + verb: str = "execute" + + +class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"): + """Action which scans an application.""" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + verb: str = "scan" + + +class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"): + """Action which closes an application.""" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + verb: str = "close" + + +class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"): + """Action which fixes an application.""" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + verb: str = "fix" + + +class NodeApplicationInstallAction(AbstractAction): + """Action which installs an application.""" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + verb: str = "install" diff --git a/src/primaite/game/agent/actions/file.py b/src/primaite/game/agent/actions/file.py new file mode 100644 index 00000000..d21daa9b --- /dev/null +++ b/src/primaite/game/agent/actions/file.py @@ -0,0 +1,79 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + + +class NodeFileAbstractAction(AbstractAction): + """Abstract base class for file actions. + + Any action which applies to a file and uses node_name, folder_name, and file_name as its only three parameters can inherit + from this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema): + node_name: str + folder_name: str + file_name: str + + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + "folder", + config.folder_name, + "file", + config.file_name, + cls.verb, + ] + + +class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"): + """Action which creates a new file in a given folder.""" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + verb: str = "create" + + +class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"): + """Action which scans a file.""" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + verb: str = "scan" + + +class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"): + """Action which deletes a file.""" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + verb: str = "delete" + + +class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"): + """Action which restores a file.""" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + verb: str = "restore" + + +class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"): + """Action which corrupts a file.""" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + verb: str = "corrupt" + + +class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"): + """Action which increases a file's access count.""" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + verb: str = "access" diff --git a/src/primaite/game/agent/actions/folder.py b/src/primaite/game/agent/actions/folder.py new file mode 100644 index 00000000..278f5658 --- /dev/null +++ b/src/primaite/game/agent/actions/folder.py @@ -0,0 +1,65 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from abc import abstractmethod +from typing import ClassVar, Dict + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + + +class NodeFolderAbstractAction(AbstractAction): + """ + Base class for folder actions. + + Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from + this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema): + node_name: str + folder_name: str + + verb: ClassVar[str] + + @classmethod + def form_request(cls, node_id: int, folder_id: int) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = cls.manager.get_node_name_by_idx(node_id) + folder_name = cls.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) + if node_name is None or folder_name is None: + return ["do_nothing"] + return ["network", "node", node_name, "file_system", "folder", folder_name, cls.verb] + + +class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"): + """Action which scans a folder.""" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + verb: str = "scan" + + +class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"): + """Action which checks the hash of a folder.""" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + verb: str = "checkhash" + + +class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"): + """Action which repairs a folder.""" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + verb: str = "repair" + + +class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"): + """Action which restores a folder.""" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + verb: str = "restore" + + +class NodeFolderCreateAction(AbstractAction, identifier="node_folder_create"): + """Action which creates a new folder.""" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + verb: str = "create" diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index 34c7c4d6..99ce091e 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -1,3 +1,4 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK """yaml example agents: @@ -10,20 +11,22 @@ agents: action_map: """ -from abc import ABC, abstractmethod - -from pydantic import BaseModel, ConfigDict -from primaite.game.game import PrimaiteGame -from primaite.interface.request import RequestFormat from __future__ import annotations -from gymnasium import spaces - import itertools -from typing import Any, ClassVar, Dict, List, Literal, Tuple, Type +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type + +from gymnasium import spaces +from pydantic import BaseModel, ConfigDict + +from primaite.game.game import _LOGGER, PrimaiteGame +from primaite.interface.request import RequestFormat + class AbstractAction(BaseModel): """Base class for actions.""" + # notes: # we actually don't need to hold any state in actions, so there's no need to define any __init__ logic. # all the init methods in the old actions are just used for holding a verb and shape, which are not really used. @@ -31,30 +34,32 @@ class AbstractAction(BaseModel): # (therefore there's no need for creating action instances, just the action class contains logic for converting # CAOS actions to requests for simulator. Similar to the network node adder, that class also doesn't need to be # instantiated.) - class ConfigSchema(BaseModel, ABC): # TODO: not sure if this better named something like `Options` + class ConfigSchema(BaseModel, ABC): # TODO: not sure if this better named something like `Options` model_config = ConfigDict(extra="forbid") type: str - _registry: ClassVar[Dict[str,Type[AbstractAction]]] = {} + _registry: ClassVar[Dict[str, Type[AbstractAction]]] = {} - def __init_subclass__(cls, identifier:str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if identifier in cls._registry: raise ValueError(f"Cannot create new action under reserved name {identifier}") cls._registry[identifier] = cls @classmethod - def form_request(self, config:ConfigSchema) -> RequestFormat: + def form_request(self, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return [] + class DoNothingAction(AbstractAction): class ConfigSchema(AbstractAction.ConfigSchema): type: Literal["do_nothing"] = "do_nothing" - def form_request(self, options:ConfigSchema) -> RequestFormat: + def form_request(self, options: ConfigSchema) -> RequestFormat: return ["do_nothing"] + class ActionManager: """Class which manages the action space for an agent.""" @@ -131,53 +136,53 @@ class ActionManager: """ # Populate lists of apps, services, files, folders, etc on nodes. - for node in nodes: - app_list = [a["application_name"] for a in node.get("applications", [])] - while len(app_list) < max_applications_per_node: - app_list.append(None) - self.application_names.append(app_list) + # for node in nodes: + # app_list = [a["application_name"] for a in node.get("applications", [])] + # while len(app_list) < max_applications_per_node: + # app_list.append(None) + # self.application_names.append(app_list) - svc_list = [s["service_name"] for s in node.get("services", [])] - while len(svc_list) < max_services_per_node: - svc_list.append(None) - self.service_names.append(svc_list) + # svc_list = [s["service_name"] for s in node.get("services", [])] + # while len(svc_list) < max_services_per_node: + # svc_list.append(None) + # self.service_names.append(svc_list) - folder_list = [f["folder_name"] for f in node.get("folders", [])] - while len(folder_list) < max_folders_per_node: - folder_list.append(None) - self.folder_names.append(folder_list) + # folder_list = [f["folder_name"] for f in node.get("folders", [])] + # while len(folder_list) < max_folders_per_node: + # folder_list.append(None) + # self.folder_names.append(folder_list) - file_sublist = [] - for folder in node.get("folders", [{"files": []}]): - file_list = [f["file_name"] for f in folder.get("files", [])] - while len(file_list) < max_files_per_folder: - file_list.append(None) - file_sublist.append(file_list) - while len(file_sublist) < max_folders_per_node: - file_sublist.append([None] * max_files_per_folder) - self.file_names.append(file_sublist) - self.protocols: List[str] = protocols - self.ports: List[str] = ports + # file_sublist = [] + # for folder in node.get("folders", [{"files": []}]): + # file_list = [f["file_name"] for f in folder.get("files", [])] + # while len(file_list) < max_files_per_folder: + # file_list.append(None) + # file_sublist.append(file_list) + # while len(file_sublist) < max_folders_per_node: + # file_sublist.append([None] * max_files_per_folder) + # self.file_names.append(file_sublist) + # self.protocols: List[str] = protocols + # self.ports: List[str] = ports - self.ip_address_list: List[str] = ip_list - self.wildcard_list: List[str] = wildcard_list - if self.wildcard_list == []: - self.wildcard_list = ["NONE"] - # action_args are settings which are applied to the action space as a whole. - global_action_args = { - "num_nodes": len(self.node_names), - "num_folders": max_folders_per_node, - "num_files": max_files_per_folder, - "num_services": max_services_per_node, - "num_applications": max_applications_per_node, - "num_nics": max_nics_per_node, - "num_acl_rules": max_acl_rules, - "num_protocols": len(self.protocols), - "num_ports": len(self.protocols), - "num_ips": len(self.ip_address_list), - "max_acl_rules": max_acl_rules, - "max_nics_per_node": max_nics_per_node, - } + # self.ip_address_list: List[str] = ip_list + # self.wildcard_list: List[str] = wildcard_list + # if self.wildcard_list == []: + # self.wildcard_list = ["NONE"] + # # action_args are settings which are applied to the action space as a whole. + # global_action_args = { + # "num_nodes": len(self.node_names), + # "num_folders": max_folders_per_node, + # "num_files": max_files_per_folder, + # "num_services": max_services_per_node, + # "num_applications": max_applications_per_node, + # "num_nics": max_nics_per_node, + # "num_acl_rules": max_acl_rules, + # "num_protocols": len(self.protocols), + # "num_ports": len(self.protocols), + # "num_ips": len(self.ip_address_list), + # "max_acl_rules": max_acl_rules, + # "max_nics_per_node": max_nics_per_node, + # } self.actions: Dict[str, AbstractAction] = {} for act_spec in actions: # each action is provided into the action space config like this: @@ -260,191 +265,191 @@ class ActionManager: """Return the gymnasium action space for this agent.""" return spaces.Discrete(len(self.action_map)) - def get_node_name_by_idx(self, node_idx: int) -> str: - """ - Get the node name corresponding to the given index. + # def get_node_name_by_idx(self, node_idx: int) -> str: + # """ + # Get the node name corresponding to the given index. - :param node_idx: The index of the node to retrieve. - :type node_idx: int - :return: The node hostname. - :rtype: str - """ - if not node_idx < len(self.node_names): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" - f"has {len(self.node_names)} nodes." - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.node_names[node_idx] + # :param node_idx: The index of the node to retrieve. + # :type node_idx: int + # :return: The node hostname. + # :rtype: str + # """ + # if not node_idx < len(self.node_names): + # msg = ( + # f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" + # f"has {len(self.node_names)} nodes." + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.node_names[node_idx] - def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: - """ - Get the folder name corresponding to the given node and folder indices. + # def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: + # """ + # Get the folder name corresponding to the given node and folder indices. - :param node_idx: The index of the node. - :type node_idx: int - :param folder_idx: The index of the folder on the node. - :type folder_idx: int - :return: The name of the folder. Or None if the node has fewer folders than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" - f" is out of range for its action space. Folder on each node: {self.folder_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.folder_names[node_idx][folder_idx] + # :param node_idx: The index of the node. + # :type node_idx: int + # :param folder_idx: The index of the folder on the node. + # :type folder_idx: int + # :return: The name of the folder. Or None if the node has fewer folders than the given index. + # :rtype: Optional[str] + # """ + # if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): + # msg = ( + # f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" + # f" is out of range for its action space. Folder on each node: {self.folder_names}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.folder_names[node_idx][folder_idx] - def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: - """Get the file name corresponding to the given node, folder, and file indices. + # def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: + # """Get the file name corresponding to the given node, folder, and file indices. - :param node_idx: The index of the node. - :type node_idx: int - :param folder_idx: The index of the folder on the node. - :type folder_idx: int - :param file_idx: The index of the file in the folder. - :type file_idx: int - :return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has - fewer files than the given index. - :rtype: Optional[str] - """ - if ( - node_idx >= len(self.file_names) - or folder_idx >= len(self.file_names[node_idx]) - or file_idx >= len(self.file_names[node_idx][folder_idx]) - ): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" - f" but this is out of range for its action space. Files on each node: {self.file_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.file_names[node_idx][folder_idx][file_idx] + # :param node_idx: The index of the node. + # :type node_idx: int + # :param folder_idx: The index of the folder on the node. + # :type folder_idx: int + # :param file_idx: The index of the file in the folder. + # :type file_idx: int + # :return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has + # fewer files than the given index. + # :rtype: Optional[str] + # """ + # if ( + # node_idx >= len(self.file_names) + # or folder_idx >= len(self.file_names[node_idx]) + # or file_idx >= len(self.file_names[node_idx][folder_idx]) + # ): + # msg = ( + # f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" + # f" but this is out of range for its action space. Files on each node: {self.file_names}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.file_names[node_idx][folder_idx][file_idx] - def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: - """Get the service name corresponding to the given node and service indices. + # def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: + # """Get the service name corresponding to the given node and service indices. - :param node_idx: The index of the node. - :type node_idx: int - :param service_idx: The index of the service on the node. - :type service_idx: int - :return: The name of the service. Or None if the node has fewer services than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" - f" is out of range for its action space. Services on each node: {self.service_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.service_names[node_idx][service_idx] + # :param node_idx: The index of the node. + # :type node_idx: int + # :param service_idx: The index of the service on the node. + # :type service_idx: int + # :return: The name of the service. Or None if the node has fewer services than the given index. + # :rtype: Optional[str] + # """ + # if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): + # msg = ( + # f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" + # f" is out of range for its action space. Services on each node: {self.service_names}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.service_names[node_idx][service_idx] - def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: - """Get the application name corresponding to the given node and service indices. + # def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: + # """Get the application name corresponding to the given node and service indices. - :param node_idx: The index of the node. - :type node_idx: int - :param application_idx: The index of the service on the node. - :type application_idx: int - :return: The name of the service. Or None if the node has fewer services than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " - f"this is out of range for its action space. Applications on each node: {self.application_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.application_names[node_idx][application_idx] + # :param node_idx: The index of the node. + # :type node_idx: int + # :param application_idx: The index of the service on the node. + # :type application_idx: int + # :return: The name of the service. Or None if the node has fewer services than the given index. + # :rtype: Optional[str] + # """ + # if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): + # msg = ( + # f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " + # f"this is out of range for its action space. Applications on each node: {self.application_names}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.application_names[node_idx][application_idx] - def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: - """Get the internet protocol corresponding to the given index. + # def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: + # """Get the internet protocol corresponding to the given index. - :param protocol_idx: The index of the protocol to retrieve. - :type protocol_idx: int - :return: The protocol. - :rtype: str - """ - if protocol_idx >= len(self.protocols): - msg = ( - f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" - f" is out of range for its action space. Protocols: {self.protocols}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.protocols[protocol_idx] + # :param protocol_idx: The index of the protocol to retrieve. + # :type protocol_idx: int + # :return: The protocol. + # :rtype: str + # """ + # if protocol_idx >= len(self.protocols): + # msg = ( + # f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" + # f" is out of range for its action space. Protocols: {self.protocols}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.protocols[protocol_idx] - def get_ip_address_by_idx(self, ip_idx: int) -> str: - """ - Get the IP address corresponding to the given index. + # def get_ip_address_by_idx(self, ip_idx: int) -> str: + # """ + # Get the IP address corresponding to the given index. - :param ip_idx: The index of the IP address to retrieve. - :type ip_idx: int - :return: The IP address. - :rtype: str - """ - if ip_idx >= len(self.ip_address_list): - msg = ( - f"Error: agent attempted to perform an action on ip address {ip_idx} but this" - f" is out of range for its action space. IP address list: {self.ip_address_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.ip_address_list[ip_idx] + # :param ip_idx: The index of the IP address to retrieve. + # :type ip_idx: int + # :return: The IP address. + # :rtype: str + # """ + # if ip_idx >= len(self.ip_address_list): + # msg = ( + # f"Error: agent attempted to perform an action on ip address {ip_idx} but this" + # f" is out of range for its action space. IP address list: {self.ip_address_list}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.ip_address_list[ip_idx] - def get_wildcard_by_idx(self, wildcard_idx: int) -> str: - """ - Get the IP wildcard corresponding to the given index. + # def get_wildcard_by_idx(self, wildcard_idx: int) -> str: + # """ + # Get the IP wildcard corresponding to the given index. - :param ip_idx: The index of the IP wildcard to retrieve. - :type ip_idx: int - :return: The wildcard address. - :rtype: str - """ - if wildcard_idx >= len(self.wildcard_list): - msg = ( - f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this" - f" is out of range for its action space. Wildcard list: {self.wildcard_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.wildcard_list[wildcard_idx] + # :param ip_idx: The index of the IP wildcard to retrieve. + # :type ip_idx: int + # :return: The wildcard address. + # :rtype: str + # """ + # if wildcard_idx >= len(self.wildcard_list): + # msg = ( + # f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this" + # f" is out of range for its action space. Wildcard list: {self.wildcard_list}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.wildcard_list[wildcard_idx] - def get_port_by_idx(self, port_idx: int) -> str: - """ - Get the port corresponding to the given index. + # def get_port_by_idx(self, port_idx: int) -> str: + # """ + # Get the port corresponding to the given index. - :param port_idx: The index of the port to retrieve. - :type port_idx: int - :return: The port. - :rtype: str - """ - if port_idx >= len(self.ports): - msg = ( - f"Error: agent attempted to perform an action on port {port_idx} but this" - f" is out of range for its action space. Port list: {self.ip_address_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.ports[port_idx] + # :param port_idx: The index of the port to retrieve. + # :type port_idx: int + # :return: The port. + # :rtype: str + # """ + # if port_idx >= len(self.ports): + # msg = ( + # f"Error: agent attempted to perform an action on port {port_idx} but this" + # f" is out of range for its action space. Port list: {self.ip_address_list}" + # ) + # _LOGGER.error(msg) + # raise RuntimeError(msg) + # return self.ports[port_idx] - def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: - """ - Get the NIC number corresponding to the given node and NIC indices. + # def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: + # """ + # Get the NIC number corresponding to the given node and NIC indices. - :param node_idx: The index of the node. - :type node_idx: int - :param nic_idx: The index of the NIC on the node. - :type nic_idx: int - :return: The NIC number. - :rtype: int - """ - return nic_idx + 1 + # :param node_idx: The index of the node. + # :type node_idx: int + # :param nic_idx: The index of the NIC on the node. + # :type nic_idx: int + # :return: The NIC number. + # :rtype: int + # """ + # return nic_idx + 1 @classmethod def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py new file mode 100644 index 00000000..cbf035a0 --- /dev/null +++ b/src/primaite/game/agent/actions/node.py @@ -0,0 +1,52 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from abc import abstractmethod +from typing import ClassVar, Dict + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + + +class NodeAbstractAction(AbstractAction): + """ + Abstract base class for node actions. + + Any action which applies to a node and uses node_name as its only parameter can inherit from this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema): + node_name: str + + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", config.node_name, cls.verb] + + +class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"): + """Action which scans a node's OS.""" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + verb: str = "scan" + + +class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"): + """Action which shuts down a node.""" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + verb: str = "shutdown" + + +class NodeStartupAction(NodeAbstractAction, identifier="node_startup"): + """Action which starts up a node.""" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + verb: str = "startup" + + +class NodeResetAction(NodeAbstractAction, identifier="node_reset"): + """Action which resets a node.""" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + verb: str = "reset" diff --git a/src/primaite/game/agent/actions/service.py b/src/primaite/game/agent/actions/service.py index 79d70212..97b37bde 100644 --- a/src/primaite/game/agent/actions/service.py +++ b/src/primaite/game/agent/actions/service.py @@ -1,7 +1,10 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import ClassVar + from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat + class NodeServiceAbstractAction(AbstractAction): class ConfigSchema(AbstractAction.ConfigSchema): node_name: str @@ -10,33 +13,69 @@ class NodeServiceAbstractAction(AbstractAction): verb: ClassVar[str] @classmethod - def form_request(cls, config:ConfigSchema) -> RequestFormat: + def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["network", "node", config.node_name, "service", config.service_name, cls.verb] + class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"): - verb: str = "scan" + """Action which scans a service.""" -class NodeServiceStopAction(NodeServiceAbstractAction, identifier=...): - verb: str = "stop" + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "scan" -class NodeServiceStartAction(NodeServiceAbstractAction): - verb: str = "start" -class NodeServicePauseAction(NodeServiceAbstractAction): - verb: str = "pause" +class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"): + """Action which stops a service.""" -class NodeServiceResumeAction(NodeServiceAbstractAction): - verb: str = "resume" + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "stop" -class NodeServiceRestartAction(NodeServiceAbstractAction): - verb: str = "restart" -class NodeServiceDisableAction(NodeServiceAbstractAction): - verb: str = "disable" +class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"): + """Action which starts a service.""" -class NodeServiceEnableAction(NodeServiceAbstractAction): - verb: str = "enable" + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "start" -class NodeServiceFixAction(NodeServiceAbstractAction): - verb: str = "fix" + +class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"): + """Action which pauses a service.""" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "pause" + + +class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"): + """Action which resumes a service.""" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "resume" + + +class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"): + """Action which restarts a service.""" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "restart" + + +class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"): + """Action which disables a service.""" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "disable" + + +class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"): + """Action which enables a service.""" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "enable" + + +class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"): + """Action which fixes a service.""" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + verb: str = "fix"