diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aec3ba1..de94f6f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty. +- Actions within PrimAITE are now extensible, allowing for plugin support. + ## [3.3.0] - 2024-09-04 diff --git a/docs/source/how_to_guides/extensible_actions.rst b/docs/source/how_to_guides/extensible_actions.rst new file mode 100644 index 00000000..0064a3a7 --- /dev/null +++ b/docs/source/how_to_guides/extensible_actions.rst @@ -0,0 +1,67 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + + +Extensible Actions +****************** + + +Changes to Actions class Structure. +=================================== + +Actions within PrimAITE have been updated to inherit from a base class, AbstractAction, standardising their format and allowing for easier creation of custom actions. Actions now use a ``ConfigSchema`` to define the possible configuration variables, and use pydantic to enforce correct parameters are passed through. + + +Developing Custom Actions. +========================== + +Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and contain 3 key items: + +#. ConfigSchema class + +#. Unique Identifier + +#. `from_request` method. + + +ConfigSchema +############ + +The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file. + + +Unique Identifier +################# + +When declaring a custom class, it must have a unique identifier string, that allows PrimAITE to generate the correct action when needed. + +.. code:: Python + + class CreateDirectoryAction(AbstractAction, identifier="node_folder_create") + + config: CreateDirectoryAction.ConfigSchema + + class ConfigSchema(AbstractAction.ConfigSchema): + + verb: ClassVar[str] = "create" + node_name: str + directory_name: str + + def form_request(cls, config: ConfigSchema) -> RequestFormat: + return ["network", + "node", + config.node_name, + "file_system", + config.verb, + "folder", + config.directory_name, + ] + +The above action would fail pydantic validation as the identifier "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`. + + +from_request method +################### + +PrimAITE actions need to be have a `from_request` method, which can be passed to the `RequestManager` for processing. This allows the custom action to be actioned within the simulation environment. diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst new file mode 100644 index 00000000..a01b9d8f --- /dev/null +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -0,0 +1,57 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + +Extensible Rewards +****************** +Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x as new reward +types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE +core repository). + +Changes to reward class structure. +================================== + +Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel). +Within the reward class there is a ConfigSchema class responsible for ensuring the config file data +is in the correct format. This also means there is little (if no) requirement for and `__init__` +method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`. +Each class requires an identifier string which is used by the ConfigSchema class to verify that it +hasn't previously been added to the registry. + +Inheriting from `BaseModel` removes the need for an `__init__` method but means that object +attributes need to be passed by keyword. + +To add a new reward class follow the example below. Note that the type attribute in the +`ConfigSchema` class should match the type used in the config file to define the reward. + +.. code-block:: Python + +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): + """Reward function component which rewards the agent for maintaining the integrity of a database file.""" + + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 + + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + type: str = "DATABASE_FILE_INTEGRITY" + node_hostname: str + folder_name: str + file_name: str + + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + pass + + + +Changes to YAML file. +===================== +.. code:: YAML + + There's no longer a need to provide a `dns_server` as an option in the simulation section + of the config file. diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml index dfd200f3..8c83bf79 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml @@ -55,50 +55,50 @@ agents: action_space: action_list: - - type: DONOTHING - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE + - type: do_nothing + - type: node_shutdown + - type: node_startup + - type: host_nic_enable + - type: host_nic_enable action_map: 0: - action: DONOTHING + action: do_nothing options: {} 1: - action: NODE_SHUTDOWN + action: node_shutdown options: - node_id: 0 + node_name: client_1 2: - action: NODE_SHUTDOWN + action: node_shutdown options: - node_id: 1 + node_name: server 3: - action: NODE_STARTUP + action: node_startup options: - node_id: 0 + node_name: client_1 4: - action: NODE_STARTUP + action: node_startup options: - node_id: 1 + node_name: server 5: - action: HOST_NIC_DISABLE + action: host_nic_disable options: - node_id: 0 + node_name: client_1 nic_id: 0 6: - action: HOST_NIC_DISABLE + action: host_nic_disable options: - node_id: 1 + node_name: server nic_id: 0 7: - action: HOST_NIC_ENABLE + action: host_nic_enable options: - node_id: 0 + node_name: client_1 nic_id: 0 8: - action: HOST_NIC_ENABLE + action: host_nic_enable options: - node_id: 1 + node_name: server nic_id: 0 options: nodes: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py deleted file mode 100644 index 5ec122ec..00000000 --- a/src/primaite/game/agent/actions.py +++ /dev/null @@ -1,1804 +0,0 @@ -# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -""" -This module contains the ActionManager class which belongs to the Agent class. - -An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of -AbstractAction. The ActionManager is responsible for: - 1. Creating the action space from a list of action types. - 2. Converting an integer action choice into a specific action and parameter choice. - 3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This - ensures that requests conform to the simulator's request format. -""" -import itertools -from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union - -from gymnasium import spaces -from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo - -from primaite import getLogger -from primaite.interface.request import RequestFormat - -_LOGGER = getLogger(__name__) - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - - -class AbstractAction(ABC): - """Base class for actions.""" - - @abstractmethod - def __init__(self, manager: "ActionManager", **kwargs) -> None: - """ - Init method for action. - - All action init functions should accept **kwargs as a way of ignoring extra arguments. - - Since many parameters are defined for the action space as a whole (such as max files per folder, max services - per node), we need to pass those options to every action that gets created. To prevent verbosity, these - parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply. - """ - self.name: str = "" - """Human-readable action identifier used for printing, logging, and reporting.""" - self.shape: Dict[str, int] = {} - """Dictionary describing the number of options for each parameter of this action. The keys of this dict must - align with the keyword args of the form_request method.""" - self.manager: ActionManager = manager - """Reference to the ActionManager which created this action. This is used to access the game and simulation - objects.""" - - @abstractmethod - def form_request(self) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [] - - -class DoNothingAction(AbstractAction): - """Action which does nothing. This is here to allow agents to be idle if they choose to.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - self.name = "DONOTHING" - self.shape: Dict[str, int] = { - "dummy": 1, - } - # This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1), - # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter - # with one option. This just aids the Action Manager to enumerate all possibilities. - - def form_request(self, **kwargs) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["do_nothing"] - - -class NodeServiceAbstractAction(AbstractAction): - """ - Base class for service actions. - - Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from - this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, service_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - service_name = self.manager.get_service_name_by_idx(node_id, service_id) - if node_name is None or service_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "service", service_name, self.verb] - - -class NodeServiceScanAction(NodeServiceAbstractAction): - """Action which scans a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "scan" - - -class NodeServiceStopAction(NodeServiceAbstractAction): - """Action which stops a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "stop" - - -class NodeServiceStartAction(NodeServiceAbstractAction): - """Action which starts a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "start" - - -class NodeServicePauseAction(NodeServiceAbstractAction): - """Action which pauses a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "pause" - - -class NodeServiceResumeAction(NodeServiceAbstractAction): - """Action which resumes a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "resume" - - -class NodeServiceRestartAction(NodeServiceAbstractAction): - """Action which restarts a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "restart" - - -class NodeServiceDisableAction(NodeServiceAbstractAction): - """Action which disables a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "disable" - - -class NodeServiceEnableAction(NodeServiceAbstractAction): - """Action which enables a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "enable" - - -class NodeServiceFixAction(NodeServiceAbstractAction): - """Action which fixes a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "fix" - - -class NodeApplicationAbstractAction(AbstractAction): - """ - Base class for application actions. - - Any action which applies to an application and uses node_id and application_id as its only two parameters can - inherit from this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, application_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - application_name = self.manager.get_application_name_by_idx(node_id, application_id) - if node_name is None or application_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "application", application_name, self.verb] - - -class NodeApplicationExecuteAction(NodeApplicationAbstractAction): - """Action which executes an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "execute" - - -class NodeApplicationScanAction(NodeApplicationAbstractAction): - """Action which scans an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "scan" - - -class NodeApplicationCloseAction(NodeApplicationAbstractAction): - """Action which closes an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "close" - - -class NodeApplicationFixAction(NodeApplicationAbstractAction): - """Action which fixes an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "fix" - - -class NodeApplicationInstallAction(AbstractAction): - """Action which installs an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes} - - def form_request(self, node_id: int, application_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - return [ - "network", - "node", - node_name, - "software_manager", - "application", - "install", - application_name, - ] - - -class ConfigureDatabaseClientAction(AbstractAction): - """Action which sets config parameters for a database client on a node.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - model_config = ConfigDict(extra="forbid") - server_ip_address: Optional[str] = None - server_password: Optional[str] = None - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - ConfigureDatabaseClientAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "DatabaseClient", "configure", config] - - -class ConfigureRansomwareScriptAction(AbstractAction): - """Action which sets config parameters for a ransomware script on a node.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this option.""" - - model_config = ConfigDict(extra="forbid") - server_ip_address: Optional[str] = None - server_password: Optional[str] = None - payload: Optional[str] = None - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "RansomwareScript", "configure", config] - - -class ConfigureDoSBotAction(AbstractAction): - """Action which sets config parameters for a DoS bot on a node.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - model_config = ConfigDict(extra="forbid") - target_ip_address: Optional[str] = None - target_port: Optional[str] = None - payload: Optional[str] = None - repeat: Optional[bool] = None - port_scan_p_of_success: Optional[float] = None - dos_intensity: Optional[float] = None - max_sessions: Optional[int] = None - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - self._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "DoSBot", "configure", config] - - -class NodeApplicationRemoveAction(AbstractAction): - """Action which removes/uninstalls an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes} - - def form_request(self, node_id: int, application_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "software_manager", "application", "uninstall", application_name] - - -class NodeFolderAbstractAction(AbstractAction): - """ - Base class for folder actions. - - Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from - this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, folder_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) - if node_name is None or folder_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "folder", folder_name, self.verb] - - -class NodeFolderScanAction(NodeFolderAbstractAction): - """Action which scans a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "scan" - - -class NodeFolderCheckhashAction(NodeFolderAbstractAction): - """Action which checks the hash of a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "checkhash" - - -class NodeFolderRepairAction(NodeFolderAbstractAction): - """Action which repairs a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "repair" - - -class NodeFolderRestoreAction(NodeFolderAbstractAction): - """Action which restores a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "restore" - - -class NodeFileCreateAction(AbstractAction): - """Action which creates a new file in a given folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "create" - - def form_request( - self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False - ) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "create", "file", folder_name, file_name, force] - - -class NodeFolderCreateAction(AbstractAction): - """Action which creates a new folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "create" - - def form_request(self, node_id: int, folder_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None or folder_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "create", "folder", folder_name] - - -class NodeFileAbstractAction(AbstractAction): - """Abstract base class for file actions. - - Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit - from this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) - file_name = self.manager.get_file_name_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "folder", folder_name, "file", file_name, self.verb] - - -class NodeFileScanAction(NodeFileAbstractAction): - """Action which scans a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "scan" - - -class NodeFileCheckhashAction(NodeFileAbstractAction): - """Action which checks the hash of a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "checkhash" - - -class NodeFileDeleteAction(NodeFileAbstractAction): - """Action which deletes a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "delete" - - def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) - file_name = self.manager.get_file_name_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "delete", "file", folder_name, file_name] - - -class NodeFileRepairAction(NodeFileAbstractAction): - """Action which repairs a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "repair" - - -class NodeFileRestoreAction(NodeFileAbstractAction): - """Action which restores a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "restore" - - -class NodeFileCorruptAction(NodeFileAbstractAction): - """Action which corrupts a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "corrupt" - - -class NodeFileAccessAction(AbstractAction): - """Action which increases a file's access count.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "access" - - def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "access", folder_name, file_name] - - -class NodeAbstractAction(AbstractAction): - """ - Abstract base class for node actions. - - Any action which applies to a node and uses node_id as its only parameter can inherit from this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, self.verb] - - -class NodeOSScanAction(NodeAbstractAction): - """Action which scans a node's OS.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "scan" - - -class NodeShutdownAction(NodeAbstractAction): - """Action which shuts down a node.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "shutdown" - - -class NodeStartupAction(NodeAbstractAction): - """Action which starts up a node.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "startup" - - -class NodeResetAction(NodeAbstractAction): - """Action which resets a node.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "reset" - - -class RouterACLAddRuleAction(AbstractAction): - """Action which adds a rule to a router's ACL.""" - - class ACLRuleOptions(BaseModel): - """Validator for ACL_ADD_RULE options.""" - - target_router: str - """On which router to add the rule, must be specified.""" - position: int - """At what position to add the rule, must be specified.""" - permission: Literal[1, 2] - """Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY.""" - source_ip_id: int = Field(default=1, ge=1) - """Rule source IP address. By default, all ip addresses.""" - source_wildcard_id: int = Field(default=0, ge=0) - """Rule source IP wildcard. By default, use the wildcard at index 0 from action manager.""" - source_port_id: int = Field(default=1, ge=1) - """Rule source port. By default, all source ports.""" - dest_ip_id: int = Field(default=1, ge=1) - """Rule destination IP address. By default, all ip addresses.""" - dest_wildcard_id: int = Field(default=0, ge=0) - """Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager.""" - dest_port_id: int = Field(default=1, ge=1) - """Rule destination port. By default, all destination ports.""" - protocol_id: int = Field(default=1, ge=1) - """Rule protocol. By default, all protocols.""" - - @field_validator( - "source_ip_id", - "source_port_id", - "source_wildcard_id", - "dest_ip_id", - "dest_port_id", - "dest_wildcard_id", - "protocol_id", - mode="before", - ) - @classmethod - def not_none(cls, v: str, info: ValidationInfo) -> int: - """If None is passed, use the default value instead.""" - if v is None: - return cls.model_fields[info.field_name].default - return v - - def __init__( - self, - manager: "ActionManager", - max_acl_rules: int, - num_ips: int, - num_ports: int, - num_protocols: int, - **kwargs, - ) -> None: - """Init method for RouterACLAddRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - :param num_ips: Number of IP addresses in the simulation. - :type num_ips: int - :param num_ports: Number of ports in the simulation. - :type num_ports: int - :param num_protocols: Number of protocols in the simulation. - :type num_protocols: int - """ - super().__init__(manager=manager) - num_permissions = 3 - self.shape: Dict[str, int] = { - "position": max_acl_rules, - "permission": num_permissions, - "source_ip_id": num_ips, - "dest_ip_id": num_ips, - "source_port_id": num_ports, - "dest_port_id": num_ports, - "protocol_id": num_protocols, - } - - def form_request( - self, - target_router: str, - position: int, - permission: int, - source_ip_id: int, - source_wildcard_id: int, - dest_ip_id: int, - dest_wildcard_id: int, - source_port_id: int, - dest_port_id: int, - protocol_id: int, - ) -> List[str]: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # Validate incoming data. - parsed_options = RouterACLAddRuleAction.ACLRuleOptions( - target_router=target_router, - position=position, - permission=permission, - source_ip_id=source_ip_id, - source_wildcard_id=source_wildcard_id, - dest_ip_id=dest_ip_id, - dest_wildcard_id=dest_wildcard_id, - source_port_id=source_port_id, - dest_port_id=dest_port_id, - protocol_id=protocol_id, - ) - if parsed_options.permission == 1: - permission_str = "PERMIT" - elif parsed_options.permission == 2: - permission_str = "DENY" - else: - _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") - - if parsed_options.protocol_id == 1: - protocol = "ALL" - else: - protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2) - # subtract 2 to account for UNUSED=0 and ALL=1. - - if parsed_options.source_ip_id == 1: - src_ip = "ALL" - else: - src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id) - - if parsed_options.source_port_id == 1: - src_port = "ALL" - else: - src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if parsed_options.dest_ip_id == 1: - dst_ip = "ALL" - else: - dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id) - - if parsed_options.dest_port_id == 1: - dst_port = "ALL" - else: - dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - return [ - "network", - "node", - target_router, - "acl", - "add_rule", - permission_str, - protocol, - str(src_ip), - src_wildcard, - src_port, - str(dst_ip), - dst_wildcard, - dst_port, - position, - ] - - -class RouterACLRemoveRuleAction(AbstractAction): - """Action which removes a rule from a router's ACL.""" - - def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: - """Init method for RouterACLRemoveRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"position": max_acl_rules} - - def form_request(self, target_router: str, position: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", target_router, "acl", "remove_rule", position] - - -class FirewallACLAddRuleAction(AbstractAction): - """Action which adds a rule to a firewall port's ACL.""" - - def __init__( - self, - manager: "ActionManager", - max_acl_rules: int, - num_ips: int, - num_ports: int, - num_protocols: int, - **kwargs, - ) -> None: - """Init method for FirewallACLAddRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - :param num_ips: Number of IP addresses in the simulation. - :type num_ips: int - :param num_ports: Number of ports in the simulation. - :type num_ports: int - :param num_protocols: Number of protocols in the simulation. - :type num_protocols: int - """ - super().__init__(manager=manager) - num_permissions = 3 - self.shape: Dict[str, int] = { - "position": max_acl_rules, - "permission": num_permissions, - "source_ip_id": num_ips, - "dest_ip_id": num_ips, - "source_port_id": num_ports, - "dest_port_id": num_ports, - "protocol_id": num_protocols, - } - - def form_request( - self, - target_firewall_nodename: str, - firewall_port_name: str, - firewall_port_direction: str, - position: int, - permission: int, - source_ip_id: int, - source_wildcard_id: int, - dest_ip_id: int, - dest_wildcard_id: int, - source_port_id: int, - dest_port_id: int, - protocol_id: int, - ) -> List[str]: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if permission == 0: - permission_str = "UNUSED" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - elif permission == 1: - permission_str = "PERMIT" - elif permission == 2: - permission_str = "DENY" - else: - _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") - - if protocol_id == 0: - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - - if protocol_id == 1: - protocol = "ALL" - else: - protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2) - # subtract 2 to account for UNUSED=0 and ALL=1. - - if source_ip_id == 0: - return ["do_nothing"] # invalid formulation - elif source_ip_id == 1: - src_ip = "ALL" - else: - src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if source_port_id == 0: - return ["do_nothing"] # invalid formulation - elif source_port_id == 1: - src_port = "ALL" - else: - src_port = self.manager.get_port_by_idx(source_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if dest_ip_id == 0: - return ["do_nothing"] # invalid formulation - elif dest_ip_id == 1: - dst_ip = "ALL" - else: - dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if dest_port_id == 0: - return ["do_nothing"] # invalid formulation - elif dest_port_id == 1: - dst_port = "ALL" - else: - dst_port = self.manager.get_port_by_idx(dest_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id) - dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id) - - return [ - "network", - "node", - target_firewall_nodename, - firewall_port_name, - firewall_port_direction, - "acl", - "add_rule", - permission_str, - protocol, - str(src_ip), - src_wildcard, - src_port, - str(dst_ip), - dst_wildcard, - dst_port, - position, - ] - - -class FirewallACLRemoveRuleAction(AbstractAction): - """Action which removes a rule from a firewall port's ACL.""" - - def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: - """Init method for RouterACLRemoveRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"position": max_acl_rules} - - def form_request( - self, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int - ) -> List[str]: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - target_firewall_nodename, - firewall_port_name, - firewall_port_direction, - "acl", - "remove_rule", - position, - ] - - -class HostNICAbstractAction(AbstractAction): - """ - Abstract base class for NIC actions. - - Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base - class. - """ - - def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - """Init method for HostNICAbstractAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param num_nodes: Number of nodes in the simulation. - :type num_nodes: int - :param max_nics_per_node: Maximum number of NICs per node. - :type max_nics_per_node: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, nic_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_idx=node_id) - nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id) - if node_name is None or nic_num is None: - return ["do_nothing"] - return ["network", "node", node_name, "network_interface", nic_num, self.verb] - - -class HostNICEnableAction(HostNICAbstractAction): - """Action which enables a NIC.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb: str = "enable" - - -class HostNICDisableAction(HostNICAbstractAction): - """Action which disables a NIC.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb: str = "disable" - - -class NetworkPortEnableAction(AbstractAction): - """Action which enables are port on a router or a firewall.""" - - def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None: - """Init method for NetworkPortEnableAction. - - :param max_nics_per_node: Maximum number of NICs per node. - :type max_nics_per_node: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - - def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if target_nodename is None or port_id is None: - return ["do_nothing"] - return ["network", "node", target_nodename, "network_interface", port_id, "enable"] - - -class NetworkPortDisableAction(AbstractAction): - """Action which disables are port on a router or a firewall.""" - - def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None: - """Init method for NetworkPortDisableAction. - - :param max_nics_per_node: Maximum number of NICs per node. - :type max_nics_per_node: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - - def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if target_nodename is None or port_id is None: - return ["do_nothing"] - return ["network", "node", target_nodename, "network_interface", port_id, "disable"] - - -class NodeNMAPPingScanAction(AbstractAction): - """Action which performs an NMAP ping scan.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, source_node: str, target_ip_address: Union[str, List[str]], show: Optional[bool] = False - ) -> List[str]: # noqa - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - source_node, - "application", - "NMAP", - "ping_scan", - {"target_ip_address": target_ip_address, "show": show}, - ] - - -class NodeNMAPPortScanAction(AbstractAction): - """Action which performs an NMAP port scan.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, - source_node: str, - target_ip_address: Union[str, List[str]], - target_protocol: Optional[Union[str, List[str]]] = None, - target_port: Optional[Union[str, List[str]]] = None, - show: Optional[bool] = False, - ) -> List[str]: # noqa - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - source_node, - "application", - "NMAP", - "port_scan", - { - "target_ip_address": target_ip_address, - "target_port": target_port, - "target_protocol": target_protocol, - "show": show, - }, - ] - - -class NodeNetworkServiceReconAction(AbstractAction): - """Action which performs an NMAP network service recon (ping scan followed by port scan).""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, - source_node: str, - target_ip_address: Union[str, List[str]], - target_protocol: Optional[Union[str, List[str]]] = None, - target_port: Optional[Union[str, List[str]]] = None, - show: Optional[bool] = False, - ) -> List[str]: # noqa - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - source_node, - "application", - "NMAP", - "network_service_recon", - { - "target_ip_address": target_ip_address, - "target_port": target_port, - "target_protocol": target_protocol, - "show": show, - }, - ] - - -class ConfigureC2BeaconAction(AbstractAction): - """Action which configures a C2 Beacon based on the parameters given.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - c2_server_ip_address: str - keep_alive_frequency: int = Field(default=5, ge=1) - masquerade_protocol: str = Field(default="TCP") - masquerade_port: str = Field(default="HTTP") - - @field_validator( - "c2_server_ip_address", - "keep_alive_frequency", - "masquerade_protocol", - "masquerade_port", - mode="before", - ) - @classmethod - def not_none(cls, v: str, info: ValidationInfo) -> int: - """If None is passed, use the default value instead.""" - if v is None: - return cls.model_fields[info.field_name].default - return v - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - config = ConfigureC2BeaconAction._Opts( - c2_server_ip_address=config["c2_server_ip_address"], - keep_alive_frequency=config["keep_alive_frequency"], - masquerade_port=config["masquerade_port"], - masquerade_protocol=config["masquerade_protocol"], - ) - - ConfigureC2BeaconAction._Opts.model_validate(config) # check that options adhere to schema - - return ["network", "node", node_name, "application", "C2Beacon", "configure", config.__dict__] - - -class NodeAccountsChangePasswordAction(AbstractAction): - """Action which changes the password for a user.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "UserManager", - "change_password", - username, - current_password, - new_password, - ] - - -class NodeSessionsRemoteLoginAction(AbstractAction): - """Action which performs a remote session login.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "Terminal", - "ssh_to_remote", - username, - password, - remote_ip, - ] - - -class NodeSessionsRemoteLogoutAction(AbstractAction): - """Action which performs a remote session logout.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, remote_ip: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip] - - -class RansomwareConfigureC2ServerAction(AbstractAction): - """Action which sends a command from the C2 Server to the C2 Beacon which configures a local RansomwareScript.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - # Using the ransomware scripts model to validate. - ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "C2Server", "ransomware_configure", config] - - -class RansomwareLaunchC2ServerAction(AbstractAction): - """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - # This action currently doesn't require any further configuration options. - return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"] - - -class ExfiltrationC2ServerAction(AbstractAction): - """Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - username: Optional[str] - password: Optional[str] - target_ip_address: str - target_file_name: str - target_folder_name: str - exfiltration_folder_name: Optional[str] - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, - node_id: int, - account: dict, - target_ip_address: str, - target_file_name: str, - target_folder_name: str, - exfiltration_folder_name: Optional[str], - ) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - - command_model = { - "target_file_name": target_file_name, - "target_folder_name": target_folder_name, - "exfiltration_folder_name": exfiltration_folder_name, - "target_ip_address": target_ip_address, - "username": account["username"], - "password": account["password"], - } - ExfiltrationC2ServerAction._Opts.model_validate(command_model) - return ["network", "node", node_name, "application", "C2Server", "exfiltrate", command_model] - - -class NodeSendRemoteCommandAction(AbstractAction): - """Action which sends a terminal command to a remote node via SSH.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "Terminal", - "send_remote_command", - remote_ip, - {"command": command}, - ] - - -class TerminalC2ServerAction(AbstractAction): - """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - commands: Union[List[RequestFormat], RequestFormat] - ip_address: Optional[str] - username: Optional[str] - password: Optional[str] - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, commands: List, ip_address: Optional[str], account: dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - - command_model = { - "commands": commands, - "ip_address": ip_address, - "username": account["username"], - "password": account["password"], - } - - TerminalC2ServerAction._Opts.model_validate(command_model) - return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model] - - -class RansomwareLaunchC2ServerAction(AbstractAction): - """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - # This action currently doesn't require any further configuration options. - return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"] - - -class ActionManager: - """Class which manages the action space for an agent.""" - - act_class_identifiers: Dict[str, type] = { - "DONOTHING": DoNothingAction, - "NODE_SERVICE_SCAN": NodeServiceScanAction, - "NODE_SERVICE_STOP": NodeServiceStopAction, - "NODE_SERVICE_START": NodeServiceStartAction, - "NODE_SERVICE_PAUSE": NodeServicePauseAction, - "NODE_SERVICE_RESUME": NodeServiceResumeAction, - "NODE_SERVICE_RESTART": NodeServiceRestartAction, - "NODE_SERVICE_DISABLE": NodeServiceDisableAction, - "NODE_SERVICE_ENABLE": NodeServiceEnableAction, - "NODE_SERVICE_FIX": NodeServiceFixAction, - "NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction, - "NODE_APPLICATION_SCAN": NodeApplicationScanAction, - "NODE_APPLICATION_CLOSE": NodeApplicationCloseAction, - "NODE_APPLICATION_FIX": NodeApplicationFixAction, - "NODE_APPLICATION_INSTALL": NodeApplicationInstallAction, - "NODE_APPLICATION_REMOVE": NodeApplicationRemoveAction, - "NODE_FILE_SCAN": NodeFileScanAction, - "NODE_FILE_CREATE": NodeFileCreateAction, - "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, - "NODE_FILE_DELETE": NodeFileDeleteAction, - "NODE_FILE_REPAIR": NodeFileRepairAction, - "NODE_FILE_RESTORE": NodeFileRestoreAction, - "NODE_FILE_CORRUPT": NodeFileCorruptAction, - "NODE_FILE_ACCESS": NodeFileAccessAction, - "NODE_FOLDER_CREATE": NodeFolderCreateAction, - "NODE_FOLDER_SCAN": NodeFolderScanAction, - "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction, - "NODE_FOLDER_REPAIR": NodeFolderRepairAction, - "NODE_FOLDER_RESTORE": NodeFolderRestoreAction, - "NODE_OS_SCAN": NodeOSScanAction, - "NODE_SHUTDOWN": NodeShutdownAction, - "NODE_STARTUP": NodeStartupAction, - "NODE_RESET": NodeResetAction, - "ROUTER_ACL_ADDRULE": RouterACLAddRuleAction, - "ROUTER_ACL_REMOVERULE": RouterACLRemoveRuleAction, - "FIREWALL_ACL_ADDRULE": FirewallACLAddRuleAction, - "FIREWALL_ACL_REMOVERULE": FirewallACLRemoveRuleAction, - "HOST_NIC_ENABLE": HostNICEnableAction, - "HOST_NIC_DISABLE": HostNICDisableAction, - "NETWORK_PORT_ENABLE": NetworkPortEnableAction, - "NETWORK_PORT_DISABLE": NetworkPortDisableAction, - "NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction, - "NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction, - "NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction, - "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, - "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, - "CONFIGURE_DOSBOT": ConfigureDoSBotAction, - "CONFIGURE_C2_BEACON": ConfigureC2BeaconAction, - "C2_SERVER_RANSOMWARE_LAUNCH": RansomwareLaunchC2ServerAction, - "C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction, - "C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction, - "C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction, - "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, - "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, - "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, - "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, - } - """Dictionary which maps action type strings to the corresponding action class.""" - - def __init__( - self, - actions: List[Dict], # stores list of actions available to agent - nodes: List[Dict], # extra configuration for each node - max_folders_per_node: int = 2, # allows calculating shape - max_files_per_folder: int = 2, # allows calculating shape - max_services_per_node: int = 2, # allows calculating shape - max_applications_per_node: int = 2, # allows calculating shape - max_nics_per_node: int = 8, # allows calculating shape - max_acl_rules: int = 10, # allows calculating shape - protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol - ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port - ip_list: List[str] = [], # to allow us to map an index to an ip address. - wildcard_list: List[str] = [], # to allow mapping from wildcard index to - act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions - ) -> None: - """Init method for ActionManager. - - :param game: Reference to the game to which the agent belongs. - :type game: PrimaiteGame - :param actions: List of action specs which should be made available to the agent. The keys of each spec are: - 'type' and 'options' for passing any options to the action class's init method - :type actions: List[dict] - :param nodes: Extra configuration for each node. - :type nodes: List[Dict] - :param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape. - :type max_folders_per_node: int - :param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape. - :type max_files_per_folder: int - :param max_services_per_node: Maximum number of services per node. Used for calculating action shape. - :type max_services_per_node: int - :param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape. - :type max_nics_per_node: int - :param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape. - :type max_acl_rules: int - :param protocols: List of protocols that are available in the simulation. Used for calculating action shape. - :type protocols: List[str] - :param ports: List of ports that are available in the simulation. Used for calculating action shape. - :type ports: List[str] - :param ip_list: List of IP addresses that known to this agent. Used for calculating action shape. - :type ip_list: Optional[List[str]] - :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. - :type act_map: Optional[Dict[int, Dict]] - """ - self.node_names: List[str] = [n["node_name"] for n in nodes] - """List of node names in this action space. The list order is the mapping between node index and node name.""" - self.application_names: List[List[str]] = [] - """ - List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name. - The first index corresponds to node id, the second index is the app id on that particular node. - For instance, self.application_names[0][2] is the name of the third application on the first node. - """ - self.service_names: List[List[str]] = [] - """ - List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name. - The first index corresponds to node id, the second index is the service id on that particular node. - For instance, self.service_names[0][2] is the name of the third service on the first node. - """ - self.folder_names: List[List[str]] = [] - """ - List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder - name. The first index corresponds to node id, the second index is the folder id on that particular node. - For instance, self.folder_names[0][2] is the name of the third folder on the first node. - """ - self.file_names: List[List[List[str]]] = [] - """ - List of files per folder per node. The list order gives the three-index mapping between - (node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the - folder id on that particular node, and the third index is the file id in that particular folder. - For instance, self.file_names[0][2][1] is the name of the second file in the third folder on the first node. - """ - - # Populate lists of apps, services, files, folders, etc on nodes. - for node in nodes: - app_list = [a["application_name"] for a in node.get("applications", [])] - while len(app_list) < max_applications_per_node: - app_list.append(None) - self.application_names.append(app_list) - - svc_list = [s["service_name"] for s in node.get("services", [])] - while len(svc_list) < max_services_per_node: - svc_list.append(None) - self.service_names.append(svc_list) - - folder_list = [f["folder_name"] for f in node.get("folders", [])] - while len(folder_list) < max_folders_per_node: - folder_list.append(None) - self.folder_names.append(folder_list) - - file_sublist = [] - for folder in node.get("folders", [{"files": []}]): - file_list = [f["file_name"] for f in folder.get("files", [])] - while len(file_list) < max_files_per_folder: - file_list.append(None) - file_sublist.append(file_list) - while len(file_sublist) < max_folders_per_node: - file_sublist.append([None] * max_files_per_folder) - self.file_names.append(file_sublist) - self.protocols: List[str] = protocols - self.ports: List[str] = ports - - self.ip_address_list: List[str] = ip_list - self.wildcard_list: List[str] = wildcard_list - if self.wildcard_list == []: - self.wildcard_list = ["NONE"] - # action_args are settings which are applied to the action space as a whole. - global_action_args = { - "num_nodes": len(self.node_names), - "num_folders": max_folders_per_node, - "num_files": max_files_per_folder, - "num_services": max_services_per_node, - "num_applications": max_applications_per_node, - "num_nics": max_nics_per_node, - "num_acl_rules": max_acl_rules, - "num_protocols": len(self.protocols), - "num_ports": len(self.protocols), - "num_ips": len(self.ip_address_list), - "max_acl_rules": max_acl_rules, - "max_nics_per_node": max_nics_per_node, - } - self.actions: Dict[str, AbstractAction] = {} - for act_spec in actions: - # each action is provided into the action space config like this: - # - type: ACTION_TYPE - # options: - # option_1: value1 - # option_2: value2 - # where `type` decides which AbstractAction subclass should be used - # and `options` is an optional dict of options to pass to the init method of the action class - act_type = act_spec.get("type") - act_options = act_spec.get("options", {}) - self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options) - - self.action_map: Dict[int, Tuple[str, Dict]] = {} - """ - Action mapping that converts an integer to a specific action and parameter choice. - - For example : - {0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})} - """ - if act_map is None: - # raise RuntimeError("Action map must be specified in the config file.") - pass - else: - self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()} - # make sure all numbers between 0 and N are represented as dict keys in action map - assert all([i in self.action_map.keys() for i in range(len(self.action_map))]) - - def _enumerate_actions( - self, - ) -> Dict[int, Tuple[str, Dict]]: - """Generate a list of all the possible actions that could be taken. - - This enumerates all actions all combinations of parameters you could choose for those actions. The output - of this function is intended to populate the self.action_map parameter in the situation where the user provides - a list of action types, but doesn't specify any subset of actions that should be made available to the agent. - - The enumeration relies on the Actions' `shape` attribute. - - :return: An action map maps consecutive integers to a combination of Action type and parameter choices. - An example output could be: - {0: ("DONOTHING", {'dummy': 0}), - 1: ("NODE_OS_SCAN", {'node_id': 0}), - 2: ("NODE_OS_SCAN", {'node_id': 1}), - 3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}), - ... #etc... - } - :rtype: Dict[int, Tuple[AbstractAction, Dict]] - """ - all_action_possibilities = [] - for act_name, action in self.actions.items(): - param_names = list(action.shape.keys()) - num_possibilities = list(action.shape.values()) - possibilities = [range(n) for n in num_possibilities] - - param_combinations = list(itertools.product(*possibilities)) - all_action_possibilities.extend( - [ - (act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))}) - for j in range(len(param_combinations)) - ] - ) - - return {i: p for i, p in enumerate(all_action_possibilities)} - - def get_action(self, action: int) -> Tuple[str, Dict]: - """Produce action in CAOS format.""" - """the agent chooses an action (as an integer), this is converted into an action in CAOS format""" - """The CAOS format is basically a action identifier, followed by parameters stored in a dictionary""" - act_identifier, act_options = self.action_map[action] - return act_identifier, act_options - - def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: - """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" - act_obj = self.actions[action_identifier] - return act_obj.form_request(**action_options) - - @property - def space(self) -> spaces.Space: - """Return the gymnasium action space for this agent.""" - return spaces.Discrete(len(self.action_map)) - - def get_node_name_by_idx(self, node_idx: int) -> str: - """ - Get the node name corresponding to the given index. - - :param node_idx: The index of the node to retrieve. - :type node_idx: int - :return: The node hostname. - :rtype: str - """ - if not node_idx < len(self.node_names): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" - f"has {len(self.node_names)} nodes." - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.node_names[node_idx] - - def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: - """ - Get the folder name corresponding to the given node and folder indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param folder_idx: The index of the folder on the node. - :type folder_idx: int - :return: The name of the folder. Or None if the node has fewer folders than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" - f" is out of range for its action space. Folder on each node: {self.folder_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.folder_names[node_idx][folder_idx] - - def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: - """Get the file name corresponding to the given node, folder, and file indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param folder_idx: The index of the folder on the node. - :type folder_idx: int - :param file_idx: The index of the file in the folder. - :type file_idx: int - :return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has - fewer files than the given index. - :rtype: Optional[str] - """ - if ( - node_idx >= len(self.file_names) - or folder_idx >= len(self.file_names[node_idx]) - or file_idx >= len(self.file_names[node_idx][folder_idx]) - ): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" - f" but this is out of range for its action space. Files on each node: {self.file_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.file_names[node_idx][folder_idx][file_idx] - - def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: - """Get the service name corresponding to the given node and service indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param service_idx: The index of the service on the node. - :type service_idx: int - :return: The name of the service. Or None if the node has fewer services than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" - f" is out of range for its action space. Services on each node: {self.service_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.service_names[node_idx][service_idx] - - def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: - """Get the application name corresponding to the given node and service indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param application_idx: The index of the service on the node. - :type application_idx: int - :return: The name of the service. Or None if the node has fewer services than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " - f"this is out of range for its action space. Applications on each node: {self.application_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.application_names[node_idx][application_idx] - - def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: - """Get the internet protocol corresponding to the given index. - - :param protocol_idx: The index of the protocol to retrieve. - :type protocol_idx: int - :return: The protocol. - :rtype: str - """ - if protocol_idx >= len(self.protocols): - msg = ( - f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" - f" is out of range for its action space. Protocols: {self.protocols}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.protocols[protocol_idx] - - def get_ip_address_by_idx(self, ip_idx: int) -> str: - """ - Get the IP address corresponding to the given index. - - :param ip_idx: The index of the IP address to retrieve. - :type ip_idx: int - :return: The IP address. - :rtype: str - """ - if ip_idx >= len(self.ip_address_list): - msg = ( - f"Error: agent attempted to perform an action on ip address {ip_idx} but this" - f" is out of range for its action space. IP address list: {self.ip_address_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.ip_address_list[ip_idx] - - def get_wildcard_by_idx(self, wildcard_idx: int) -> str: - """ - Get the IP wildcard corresponding to the given index. - - :param ip_idx: The index of the IP wildcard to retrieve. - :type ip_idx: int - :return: The wildcard address. - :rtype: str - """ - if wildcard_idx >= len(self.wildcard_list): - msg = ( - f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this" - f" is out of range for its action space. Wildcard list: {self.wildcard_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.wildcard_list[wildcard_idx] - - def get_port_by_idx(self, port_idx: int) -> str: - """ - Get the port corresponding to the given index. - - :param port_idx: The index of the port to retrieve. - :type port_idx: int - :return: The port. - :rtype: str - """ - if port_idx >= len(self.ports): - msg = ( - f"Error: agent attempted to perform an action on port {port_idx} but this" - f" is out of range for its action space. Port list: {self.ip_address_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.ports[port_idx] - - def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: - """ - Get the NIC number corresponding to the given node and NIC indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param nic_idx: The index of the NIC on the node. - :type nic_idx: int - :return: The NIC number. - :rtype: int - """ - return nic_idx + 1 - - @classmethod - def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": - """ - Construct an ActionManager from a config definition. - - The action space config supports the following three sections: - 1. ``action_list`` - ``action_list`` contains a list action components which need to be included in the action space. - Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options - which will be passed to the action class's __init__ method during initialisation. - 2. ``action_map`` - Since the agent uses a discrete action space which acts as a flattened version of the component-based - action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful - action and values of parameters. For example action 0 can correspond to do nothing, action 1 can - correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be " - 3. ``options`` - ``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method. - These options are used to calculate the shape of the action space, and to provide additional information - to the ActionManager which is required to convert the agent's action choice into a CAOS request. - - :param game: The Primaite Game to which the agent belongs. - :type game: PrimaiteGame - :param cfg: The action space config. - :type cfg: Dict - :return: The constructed ActionManager. - :rtype: ActionManager - """ - if "ip_list" not in cfg["options"]: - cfg["options"]["ip_list"] = [] - - obj = cls( - actions=cfg["action_list"], - **cfg["options"], - protocols=game.options.protocols, - ports=game.options.ports, - act_map=cfg.get("action_map"), - ) - - return obj diff --git a/src/primaite/game/agent/actions/__init__.py b/src/primaite/game/agent/actions/__init__.py new file mode 100644 index 00000000..8517ded8 --- /dev/null +++ b/src/primaite/game/agent/actions/__init__.py @@ -0,0 +1,33 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from primaite.game.agent.actions import ( + abstract, + acl, + application, + file, + folder, + host_nic, + manager, + network, + node, + service, + session, + software, +) +from primaite.game.agent.actions.manager import ActionManager + +__all__ = ( + "abstract", + "acl", + "application", + "software", + "file", + "folder", + "host_nic", + "manager", + "network", + "node", + "service", + "session", + "ActionManager", +) diff --git a/src/primaite/game/agent/actions/abstract.py b/src/primaite/game/agent/actions/abstract.py new file mode 100644 index 00000000..c570119b --- /dev/null +++ b/src/primaite/game/agent/actions/abstract.py @@ -0,0 +1,36 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from __future__ import annotations + +from abc import ABC +from typing import Any, ClassVar, Dict, Optional, Type + +from pydantic import BaseModel, ConfigDict + +from primaite.interface.request import RequestFormat + + +class AbstractAction(BaseModel, ABC): + """Base class for actions.""" + + config: "AbstractAction.ConfigSchema" + + class ConfigSchema(BaseModel, ABC): + """Base configuration schema for Actions.""" + + model_config = ConfigDict(extra="forbid") + type: str + + _registry: ClassVar[Dict[str, Type[AbstractAction]]] = {} + + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier is None: + return + if identifier in cls._registry: + raise ValueError(f"Cannot create new action under reserved name {identifier}") + cls._registry[identifier] = cls + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + pass diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py new file mode 100644 index 00000000..6022f697 --- /dev/null +++ b/src/primaite/game/agent/actions/acl.py @@ -0,0 +1,188 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from __future__ import annotations + +from abc import ABC +from typing import List + +from pydantic import field_validator + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat +from primaite.utils.validation.ip_protocol import protocol_validator +from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address +from primaite.utils.validation.port import port_validator + +__all__ = ( + "RouterACLAddRuleAction", + "RouterACLRemoveRuleAction", + "FirewallACLAddRuleAction", + "FirewallACLRemoveRuleAction", +) + + +class ACLAddRuleAbstractAction(AbstractAction, ABC): + """Base abstract class for ACL add rule actions.""" + + config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration Schema base for ACL add rule abstract actions.""" + + src_ip: IPV4Address + protocol_name: str + permission: str + position: int + dst_ip: IPV4Address + src_port: int + dst_port: int + src_wildcard: int + dst_wildcard: int + + @field_validator( + "src_port", + "dst_port", + mode="before", + ) + @classmethod + def valid_port(cls, v: str) -> int: + """Check that inputs are valid.""" + return port_validator(v) + + @field_validator( + "src_ip", + "dst_ip", + mode="before", + ) + @classmethod + def valid_ip(cls, v: str) -> str: + """Check that a valid IP has been provided for src and dst.""" + return ipv4_validator(v) + + @field_validator( + "protocol_name", + mode="before", + ) + @classmethod + def is_valid_protocol(cls, v: str) -> bool: + """Check that we are using a valid protocol.""" + return protocol_validator(v) + + +class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"): + """Base abstract class for ACL remove rule actions.""" + + config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration Schema base for ACL remove rule abstract actions.""" + + position: int + + +class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"): + """Action which adds a rule to a router's ACL.""" + + config: "RouterACLAddRuleAction.ConfigSchema" + + class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema): + """Configuration Schema for RouterACLAddRuleAction.""" + + target_router: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.target_router, + "acl", + "add_rule", + config.permission, + config.protocol_name, + config.src_ip, + config.src_wildcard, + config.src_port, + config.dst_ip, + config.dst_wildcard, + config.dst_port, + config.position, + ] + + +class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"): + """Action which removes a rule from a router's ACL.""" + + config: "RouterACLRemoveRuleAction.ConfigSchema" + + class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema): + """Configuration schema for RouterACLRemoveRuleAction.""" + + target_router: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", config.target_router, "acl", "remove_rule", config.position] + + +class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"): + """Action which adds a rule to a firewall port's ACL.""" + + config: "FirewallACLAddRuleAction.ConfigSchema" + + class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema): + """Configuration schema for FirewallACLAddRuleAction.""" + + target_firewall_nodename: str + firewall_port_name: str + firewall_port_direction: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.target_firewall_nodename, + config.firewall_port_name, + config.firewall_port_direction, + "acl", + "add_rule", + config.permission, + config.protocol_name, + config.src_ip, + config.src_wildcard, + config.src_port, + config.dst_ip, + config.dst_wildcard, + config.dst_port, + config.position, + ] + + +class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"): + """Action which removes a rule from a firewall port's ACL.""" + + config: "FirewallACLRemoveRuleAction.ConfigSchema" + + class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema): + """Configuration schema for FirewallACLRemoveRuleAction.""" + + target_firewall_nodename: str + firewall_port_name: str + firewall_port_direction: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.target_firewall_nodename, + config.firewall_port_name, + config.firewall_port_direction, + "acl", + "remove_rule", + config.position, + ] diff --git a/src/primaite/game/agent/actions/application.py b/src/primaite/game/agent/actions/application.py new file mode 100644 index 00000000..223effc4 --- /dev/null +++ b/src/primaite/game/agent/actions/application.py @@ -0,0 +1,137 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.abstract import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeApplicationExecuteAction", + "NodeApplicationScanAction", + "NodeApplicationCloseAction", + "NodeApplicationFixAction", + "NodeApplicationInstallAction", + "NodeApplicationRemoveAction", +) + + +class NodeApplicationAbstractAction(AbstractAction, ABC): + """ + Base class for application actions. + + Any action which applies to an application and uses node_id and application_id as its only two parameters can + inherit from this base class. + """ + + config: "NodeApplicationAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base Configuration schema for Node Application actions.""" + + node_name: str + application_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "application", + config.application_name, + config.verb, + ] + + +class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"): + """Action which executes an application.""" + + config: "NodeApplicationExecuteAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationExecuteAction.""" + + verb: str = "execute" + + +class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"): + """Action which scans an application.""" + + config: "NodeApplicationScanAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationScanAction.""" + + verb: str = "scan" + + +class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"): + """Action which closes an application.""" + + config: "NodeApplicationCloseAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationCloseAction.""" + + verb: str = "close" + + +class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"): + """Action which fixes an application.""" + + config: "NodeApplicationFixAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationFixAction.""" + + verb: str = "fix" + + +class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"): + """Action which installs an application.""" + + config: "NodeApplicationInstallAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationInstallAction.""" + + verb: str = "install" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "software_manager", + "application", + config.verb, + config.application_name, + ] + + +class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"): + """Action which removes/uninstalls an application.""" + + config: "NodeApplicationRemoveAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationRemoveAction.""" + + verb: str = "uninstall" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "software_manager", + "application", + config.verb, + config.application_name, + ] diff --git a/src/primaite/game/agent/actions/file.py b/src/primaite/game/agent/actions/file.py new file mode 100644 index 00000000..ed666773 --- /dev/null +++ b/src/primaite/game/agent/actions/file.py @@ -0,0 +1,189 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeFileCreateAction", + "NodeFileScanAction", + "NodeFileDeleteAction", + "NodeFileRestoreAction", + "NodeFileCorruptAction", + "NodeFileAccessAction", + "NodeFileCheckhashAction", + "NodeFileRepairAction", +) + + +class NodeFileAbstractAction(AbstractAction, ABC): + """Abstract base class for file actions. + + Any action which applies to a file and uses node_name, folder_name, and file_name as its + only three parameters can inherit from this base class. + """ + + config: "NodeFileAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration Schema for NodeFileAbstractAction.""" + + node_name: str + folder_name: str + file_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + "folder", + config.folder_name, + "file", + config.file_name, + config.verb, + ] + + +class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"): + """Action which creates a new file in a given folder.""" + + config: "NodeFileCreateAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileCreateAction.""" + + verb: ClassVar[str] = "create" + force: bool = False + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + "file", + config.folder_name, + config.file_name, + config.verb, + ] + + +class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"): + """Action which scans a file.""" + + config: "NodeFileScanAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"): + """Action which deletes a file.""" + + config: "NodeFileDeleteAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileDeleteAction.""" + + verb: ClassVar[str] = "delete" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + "file", + config.folder_name, + config.file_name, + ] + + +class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"): + """Action which restores a file.""" + + config: "NodeFileRestoreAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileRestoreAction.""" + + verb: ClassVar[str] = "restore" + + +class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"): + """Action which corrupts a file.""" + + config: "NodeFileCorruptAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileCorruptAction.""" + + verb: ClassVar[str] = "corrupt" + + +class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"): + """Action which increases a file's access count.""" + + config: "NodeFileAccessAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileAccessAction.""" + + verb: ClassVar[str] = "access" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + config.folder_name, + config.file_name, + ] + + +class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"): + """Action which checks the hash of a file.""" + + config: "NodeFileCheckhashAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileCheckhashAction.""" + + verb: ClassVar[str] = "checkhash" + + +class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"): + """Action which repairs a file.""" + + config: "NodeFileRepairAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration Schema for NodeFileRepairAction.""" + + verb: ClassVar[str] = "repair" diff --git a/src/primaite/game/agent/actions/folder.py b/src/primaite/game/agent/actions/folder.py new file mode 100644 index 00000000..3e1136ac --- /dev/null +++ b/src/primaite/game/agent/actions/folder.py @@ -0,0 +1,117 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeFolderScanAction", + "NodeFolderCheckhashAction", + "NodeFolderRepairAction", + "NodeFolderRestoreAction", + "NodeFolderCreateAction", +) + + +class NodeFolderAbstractAction(AbstractAction, ABC): + """ + Base class for folder actions. + + Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from + this base class. + """ + + config: "NodeFolderAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base configuration schema for NodeFolder actions.""" + + node_name: str + folder_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + "folder", + config.folder_name, + config.verb, + ] + + +class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"): + """Action which scans a folder.""" + + config: "NodeFolderScanAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"): + """Action which checks the hash of a folder.""" + + config: "NodeFolderCheckhashAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderCheckhashAction.""" + + verb: ClassVar[str] = "checkhash" + + +class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"): + """Action which repairs a folder.""" + + config: "NodeFolderRepairAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderRepairAction.""" + + verb: ClassVar[str] = "repair" + + +class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"): + """Action which restores a folder.""" + + config: "NodeFolderRestoreAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderRestoreAction.""" + + verb: ClassVar[str] = "restore" + + +class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"): + """Action which creates a new folder.""" + + config: "NodeFolderCreateAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderCreateAction.""" + + verb: ClassVar[str] = "create" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + "folder", + config.folder_name, + ] diff --git a/src/primaite/game/agent/actions/host_nic.py b/src/primaite/game/agent/actions/host_nic.py new file mode 100644 index 00000000..b9206b9c --- /dev/null +++ b/src/primaite/game/agent/actions/host_nic.py @@ -0,0 +1,62 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ("HostNICEnableAction", "HostNICDisableAction") + + +class HostNICAbstractAction(AbstractAction, ABC): + """ + Abstract base class for NIC actions. + + Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base + class. + """ + + config: "HostNICAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base Configuration schema for HostNIC actions.""" + + node_name: str + nic_num: int + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.nic_num is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "network_interface", + config.nic_num, + config.verb, + ] + + +class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"): + """Action which enables a NIC.""" + + config: "HostNICEnableAction.ConfigSchema" + + class ConfigSchema(HostNICAbstractAction.ConfigSchema): + """Configuration schema for HostNICEnableAction.""" + + verb: ClassVar[str] = "enable" + + +class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"): + """Action which disables a NIC.""" + + config: "HostNICDisableAction.ConfigSchema" + + class ConfigSchema(HostNICAbstractAction.ConfigSchema): + """Configuration schema for HostNICDisableAction.""" + + verb: ClassVar[str] = "disable" diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py new file mode 100644 index 00000000..c3e14379 --- /dev/null +++ b/src/primaite/game/agent/actions/manager.py @@ -0,0 +1,138 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +"""yaml example. + +agents: + - name: agent_1 + action_space: + actions: + - do_nothing + - node_service_start + - node_service_stop + action_map: +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +from gymnasium import spaces + +# from primaite.game.game import PrimaiteGame # TODO: Breaks things +from primaite.game.agent.actions.abstract import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ("DoNothingAction", "ActionManager") + + +class DoNothingAction(AbstractAction, identifier="do_nothing"): + """Do Nothing Action.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration Schema for do_nothingAction.""" + + type: str = "do_nothing" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["do_nothing"] + + +class ActionManager: + """Class which manages the action space for an agent.""" + + def __init__( + self, + actions: List[Dict], # stores list of actions available to agent + nodes: List[Dict], # extra configuration for each node + act_map: Optional[ + Dict[int, Dict] + ] = None, # allows restricting set of possible actions - TODO: Refactor to be a list? + *args, + **kwargs, + ) -> None: + """Init method for ActionManager. + + :param game: Reference to the game to which the agent belongs. + :type game: PrimaiteGame + :param actions: List of action specs which should be made available to the agent. The keys of each spec are: + 'type' and 'options' for passing any options to the action class's init method + :type actions: List[dict] + :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. + :type act_map: Optional[Dict[int, Dict]] + """ + self.actions: Dict[str, AbstractAction] = {} + for act_spec in actions: + act_type = act_spec.get("type") + self.actions[act_type] = AbstractAction._registry[act_type] + + self.action_map: Dict[int, Tuple[str, Dict]] = {} + """ + Action mapping that converts an integer to a specific action and parameter choice. + + For example : + {0: ("node_service_scan", {node_name:"client_1", service_name:"WebBrowser"})} + """ + if act_map is None: + # raise RuntimeError("Action map must be specified in the config file.") + pass + else: + self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()} + # make sure all numbers between 0 and N are represented as dict keys in action map + assert all([i in self.action_map.keys() for i in range(len(self.action_map))]) + + def get_action(self, action: int) -> Tuple[str, Dict]: + """Produce action in CAOS format.""" + """the agent chooses an action (as an integer), this is converted into an action in CAOS format""" + """The CAOS format is basically a action identifier, followed by parameters stored in a dictionary""" + act_identifier, act_options = self.action_map[action] + return act_identifier, act_options + + def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: + """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" + act_class = AbstractAction._registry[action_identifier] + config = act_class.ConfigSchema(**action_options) + return act_class.form_request(config=config) + + @property + def space(self) -> spaces.Space: + """Return the gymnasium action space for this agent.""" + return spaces.Discrete(len(self.action_map)) + + @classmethod + def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": # noqa: F821 + """ + Construct an ActionManager from a config definition. + + The action space config supports the following three sections: + 1. ``action_list`` + ``action_list`` contains a list action components which need to be included in the action space. + Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options + which will be passed to the action class's __init__ method during initialisation. + 2. ``action_map`` + Since the agent uses a discrete action space which acts as a flattened version of the component-based + action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful + action and values of parameters. For example action 0 can correspond to do nothing, action 1 can + correspond to "node_service_scan" with ``node_name="server"`` and + ``service_name="WebBrowser"``, action 2 can be " + 3. ``options`` + ``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method. + These options are used to calculate the shape of the action space, and to provide additional information + to the ActionManager which is required to convert the agent's action choice into a CAOS request. + + :param game: The Primaite Game to which the agent belongs. + :type game: PrimaiteGame + :param cfg: The action space config. + :type cfg: Dict + :return: The constructed ActionManager. + :rtype: ActionManager + """ + obj = cls( + actions=cfg["action_list"], + **cfg["options"], + protocols=game.options.protocols, + ports=game.options.ports, + act_map=cfg.get("action_map"), + ) + + return obj diff --git a/src/primaite/game/agent/actions/network.py b/src/primaite/game/agent/actions/network.py new file mode 100644 index 00000000..7f1e069a --- /dev/null +++ b/src/primaite/game/agent/actions/network.py @@ -0,0 +1,57 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction") + + +class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"): + """Base class for Network port actions.""" + + config: "NetworkPortAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base configuration schema for NetworkPort actions.""" + + target_nodename: str + port_id: int + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.target_nodename is None or config.port_id is None: + return ["do_nothing"] + return [ + "network", + "node", + config.target_nodename, + "network_interface", + config.port_id, + config.verb, + ] + + +class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"): + """Action which enables are port on a router or a firewall.""" + + config: "NetworkPortEnableAction.ConfigSchema" + + class ConfigSchema(NetworkPortAbstractAction.ConfigSchema): + """Configuration schema for NetworkPortEnableAction.""" + + verb: ClassVar[str] = "enable" + + +class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"): + """Action which disables are port on a router or a firewall.""" + + config: "NetworkPortDisableAction.ConfigSchema" + + class ConfigSchema(NetworkPortAbstractAction.ConfigSchema): + """Configuration schema for NetworkPortDisableAction.""" + + verb: ClassVar[str] = "disable" diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py new file mode 100644 index 00000000..4a7f725e --- /dev/null +++ b/src/primaite/game/agent/actions/node.py @@ -0,0 +1,195 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import abstractmethod +from typing import ClassVar, List, Optional, Union + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeOSScanAction", + "NodeShutdownAction", + "NodeStartupAction", + "NodeResetAction", + "NodeNMAPPingScanAction", + "NodeNMAPPortScanAction", + "NodeNetworkServiceReconAction", +) + + +class NodeAbstractAction(AbstractAction, identifier="node_abstract"): + """ + Abstract base class for node actions. + + Any action which applies to a node and uses node_name as its only parameter can inherit from this base class. + """ + + config: "NodeAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base Configuration schema for Node actions.""" + + node_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + print(config) + return ["network", "node", config.node_name, config.verb] + + +class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"): + """Action which scans a node's OS.""" + + config: "NodeOSScanAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeOSScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"): + """Action which shuts down a node.""" + + config: "NodeShutdownAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeShutdownAction.""" + + verb: ClassVar[str] = "shutdown" + + +class NodeStartupAction(NodeAbstractAction, identifier="node_startup"): + """Action which starts up a node.""" + + config: "NodeStartupAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeStartupAction.""" + + verb: ClassVar[str] = "startup" + + +class NodeResetAction(NodeAbstractAction, identifier="node_reset"): + """Action which resets a node.""" + + config: "NodeResetAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeResetAction.""" + + verb: ClassVar[str] = "reset" + + +class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"): + """Base class for NodeNMAP actions.""" + + config: "NodeNMAPAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base Configuration Schema for NodeNMAP actions.""" + + target_ip_address: Union[str, List[str]] + show: bool = False + node_name: str + + @classmethod + @abstractmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + # NMAP action requests don't share a common format for their requests + # This is just a placeholder to ensure the method is defined. + pass + + +class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"): + """Action which performs an NMAP ping scan.""" + + config: "NodeNMAPPingScanAction.ConfigSchema" + + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + """Configuration schema for NodeNMAPPingScanAction.""" + + pass + + @classmethod + def form_request(cls, config: ConfigSchema) -> List[str]: # noqa + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "application", + "NMAP", + "ping_scan", + {"target_ip_address": config.target_ip_address, "show": config.show}, + ] + + +class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"): + """Action which performs an NMAP port scan.""" + + config: "NodeNMAPPortScanAction.ConfigSchema" + + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + """Configuration Schema for NodeNMAPPortScanAction.""" + + source_node: str + target_protocol: Optional[Union[str, List[str]]] = (None,) + target_port: Optional[Union[str, List[str]]] = (None,) + show: Optional[bool] = (False,) + + @classmethod + def form_request( + cls, + config: ConfigSchema, + ) -> List[str]: # noqa + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "NMAP", + "port_scan", + { + "target_ip_address": config.target_ip_address, + "target_port": config.target_port, + "target_protocol": config.target_protocol, + "show": config.show, + }, + ] + + +class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"): + """Action which performs an NMAP network service recon (ping scan followed by port scan).""" + + config: "NodeNetworkServiceReconAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for NodeNetworkServiceReconAction.""" + + target_protocol: Optional[Union[str, List[str]]] = (None,) + target_port: Optional[Union[str, List[str]]] = (None,) + show: Optional[bool] = (False,) + + @classmethod + def form_request( + cls, + config: ConfigSchema, + ) -> List[str]: # noqa + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "NMAP", + "network_service_recon", + { + "target_ip_address": config.target_ip_address, + "target_port": config.target_port, + "target_protocol": config.target_protocol, + "show": config.show, + }, + ] diff --git a/src/primaite/game/agent/actions/service.py b/src/primaite/game/agent/actions/service.py new file mode 100644 index 00000000..4a483f28 --- /dev/null +++ b/src/primaite/game/agent/actions/service.py @@ -0,0 +1,135 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeServiceScanAction", + "NodeServiceStopAction", + "NodeServiceStartAction", + "NodeServicePauseAction", + "NodeServiceResumeAction", + "NodeServiceRestartAction", + "NodeServiceDisableAction", + "NodeServiceEnableAction", + "NodeServiceFixAction", +) + + +class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"): + """Abstract Action for Node Service related actions. + + Any actions which use node_name and service_name can inherit from this class. + """ + + config: "NodeServiceAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + node_name: str + service_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", config.node_name, "service", config.service_name, config.verb] + + +class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"): + """Action which scans a service.""" + + config: "NodeServiceScanAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"): + """Action which stops a service.""" + + config: "NodeServiceStopAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceStopAction.""" + + verb: ClassVar[str] = "stop" + + +class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"): + """Action which starts a service.""" + + config: "NodeServiceStartAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceStartAction.""" + + verb: ClassVar[str] = "start" + + +class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"): + """Action which pauses a service.""" + + config: "NodeServicePauseAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServicePauseAction.""" + + verb: ClassVar[str] = "pause" + + +class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"): + """Action which resumes a service.""" + + config: "NodeServiceResumeAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceResumeAction.""" + + verb: ClassVar[str] = "resume" + + +class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"): + """Action which restarts a service.""" + + config: "NodeServiceRestartAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceRestartAction.""" + + verb: ClassVar[str] = "restart" + + +class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"): + """Action which disables a service.""" + + config: "NodeServiceDisableAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceDisableAction.""" + + verb: ClassVar[str] = "disable" + + +class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"): + """Action which enables a service.""" + + config: "NodeServiceEnableAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceEnableAction.""" + + verb: ClassVar[str] = "enable" + + +class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"): + """Action which fixes a service.""" + + config: "NodeServiceFixAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceFixAction.""" + + verb: ClassVar[str] = "fix" diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py new file mode 100644 index 00000000..1191987b --- /dev/null +++ b/src/primaite/game/agent/actions/session.py @@ -0,0 +1,108 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import abstractmethod + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeSessionsRemoteLoginAction", + "NodeSessionsRemoteLogoutAction", + "NodeAccountChangePasswordAction", +) + + +class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"): + """Base class for NodeSession actions.""" + + config: "NodeSessionAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base configuration schema for NodeSessionAbstractActions.""" + + node_name: str + remote_ip: str + + @classmethod + @abstractmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """ + Abstract method for request forming. + + Should return the action formatted as a request which can be ingested by the PrimAITE simulation. + """ + pass + + +class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"): + """Action which performs a remote session login.""" + + config: "NodeSessionsRemoteLoginAction.ConfigSchema" + + class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + """Configuration schema for NodeSessionsRemoteLoginAction.""" + + username: str + password: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.remote_ip is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "service", + "Terminal", + "ssh_to_remote", + config.username, + config.password, + config.remote_ip, + ] + + +class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"): + """Action which performs a remote session logout.""" + + config: "NodeSessionsRemoteLogoutAction.ConfigSchema" + + class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + """Configuration schema for NodeSessionsRemoteLogoutAction.""" + + verb: str = "remote_logoff" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.remote_ip is None: + return ["do_nothing"] + return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip] + + +class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"): + """Action which changes the password for a user.""" + + config: "NodeAccountChangePasswordAction.ConfigSchema" + + class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + """Configuration schema for NodeAccountsChangePasswordAction.""" + + username: str + current_password: str + new_password: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "service", + "UserManager", + "change_password", + config.username, + config.current_password, + config.new_password, + ] diff --git a/src/primaite/game/agent/actions/software.py b/src/primaite/game/agent/actions/software.py new file mode 100644 index 00000000..760e8dfa --- /dev/null +++ b/src/primaite/game/agent/actions/software.py @@ -0,0 +1,238 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from typing import List, Optional, Union + +from pydantic import ConfigDict, Field, field_validator, ValidationInfo + +from primaite.game.agent.actions.manager import AbstractAction, ActionManager +from primaite.interface.request import RequestFormat + +__all__ = ( + "ConfigureRansomwareScriptAction", + "ConfigureDoSBotAction", + "ConfigureC2BeaconAction", + "NodeSendRemoteCommandAction", + "TerminalC2ServerAction", + "RansomwareLaunchC2ServerAction", + "ExfiltrationC2ServerAction", + "ConfigureDatabaseClientAction", +) + + +class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_ransomware_configure"): + """Action which sets config parameters for a ransomware script on a node.""" + + config: "ConfigureRansomwareScriptAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for ConfigureRansomwareScriptAction.""" + + node_name: str + server_ip_address: Optional[str] + server_password: Optional[str] + payload: Optional[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do_nothing"] + return [ + "network", + "node", + config.node_name, + "application", + "RansomwareScript", + "configure", + config.model_config, + ] + + +class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): + """Action which sets config parameters for a DoS bot on a node.""" + + config: "ConfigureDoSBotAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + model_config = ConfigDict(extra="forbid") + target_ip_address: Optional[str] = None + target_port: Optional[str] = None + payload: Optional[str] = None + repeat: Optional[bool] = None + port_scan_p_of_success: Optional[float] = None + dos_intensity: Optional[float] = None + max_sessions: Optional[int] = None + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do_nothing"] + self.ConfigSchema.model_validate(config) # check that options adhere to schema + return ["network", "node", config.node_name, "application", "DoSBot", "configure", config] + + +class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): + """Action which configures a C2 Beacon based on the parameters given.""" + + config: "ConfigureC2BeaconAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for ConfigureC2BeaconAction.""" + + node_name: str + c2_server_ip_address: str + keep_alive_frequency: int = Field(default=5, ge=1) + masquerade_protocol: str = Field(default="TCP") + masquerade_port: str = Field(default="HTTP") + + @field_validator( + "c2_server_ip_address", + "keep_alive_frequency", + "masquerade_protocol", + "masquerade_port", + mode="before", + ) + @classmethod + def not_none(cls, v: str, info: ValidationInfo) -> int: + """If None is passed, use the default value instead.""" + if v is None: + return cls.model_fields[info.field_name].default + return v + + @classmethod + def form_request(self, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config] + + +class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"): + """Action which sends a terminal command to a remote node via SSH.""" + + config: "NodeSendRemoteCommandAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for NodeSendRemoteCommandAction.""" + + node_name: str + remote_ip: str + command: RequestFormat + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "service", + "Terminal", + "send_remote_command", + config.remote_ip, + {"command": config.command}, + ] + + +class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"): + """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" + + config: "TerminalC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + commands: Union[List[RequestFormat], RequestFormat] + ip_address: Optional[str] + username: Optional[str] + password: Optional[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do_nothing"] + + command_model = { + "commands": config.commands, + "ip_address": config.ip_address, + "username": config.username, + "password": config.password, + } + return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model] + + +class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"): + """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" + + config: "RansomwareLaunchC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for RansomwareLaunchC2ServerAction.""" + + node_name: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do_nothing"] + # This action currently doesn't require any further configuration options. + return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"] + + +class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"): + """Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server.""" + + config: "ExfiltrationC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + username: Optional[str] + password: Optional[str] + target_ip_address: str + target_file_name: str + target_folder_name: str + exfiltration_folder_name: Optional[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do_nothing"] + + command_model = { + "target_file_name": config.target_file_name, + "target_folder_name": config.target_folder_name, + "exfiltration_folder_name": config.exfiltration_folder_name, + "target_ip_address": config.target_ip_address, + "username": config.username, + "password": config.password, + } + return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model] + + +class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"): + """Action which sets config parameters for a database client on a node.""" + + config: "ConfigureDatabaseClientAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + model_config = ConfigDict(extra="forbid") + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do_nothing"] + return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", config.model_config] diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index ac76a425..ec9d6c61 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -73,11 +73,13 @@ class AbstractAgent(BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) type: str = "AbstractAgent" - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier is None: + return if identifier in cls._registry: raise ValueError(f"Cannot create a new agent under reserved name {identifier}") cls._registry[identifier] = cls - super().__init_subclass__(**kwargs) @property def flatten_obs(self) -> bool: diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py index 68435eed..a38095b3 100644 --- a/src/primaite/game/agent/observations/__init__.py +++ b/src/primaite/game/agent/observations/__init__.py @@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs __all__ = [ "ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation", "LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation", - "ObservationManager", "ApplicationObservation", "ServiceObservation",] + "ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",] # fmt: on diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 49b9ab72..89c45b37 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -31,7 +31,7 @@ class AbstractObservation(ABC): """Initialise an observation. This method must be overwritten.""" self.default_observation: ObsType - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: """ Register an observation type. @@ -40,6 +40,8 @@ class AbstractObservation(ABC): :raises ValueError: When attempting to create a component with a name that is already in use. """ super().__init_subclass__(**kwargs) + if identifier is None: + return if identifier in cls._registry: raise ValueError(f"Duplicate observation component type {identifier}") cls._registry[identifier] = cls diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 7c184770..50fdaba8 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -27,9 +27,10 @@ the structure: service_ref: web_server_database_client ``` """ -from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from pydantic import BaseModel from typing_extensions import Never from primaite import getLogger @@ -42,25 +43,28 @@ _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] -class AbstractReward: +class AbstractReward(BaseModel): """Base class for reward function components.""" - @abstractmethod - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state. + config: "AbstractReward.ConfigSchema" - :param state: Current simulation state - :type state: Dict - :param last_action_response: Current agent history state - :type last_action_response: AgentHistoryItem state - :return: Reward value - :rtype: float - """ - return 0.0 + class ConfigSchema(BaseModel, ABC): + """Config schema for AbstractReward.""" + + type: str + + _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier is None: + return + if identifier in cls._registry: + raise ValueError(f"Duplicate reward {identifier}") + cls._registry[identifier] = cls @classmethod - @abstractmethod - def from_config(cls, config: dict) -> "AbstractReward": + def from_config(cls, config: Dict) -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -68,11 +72,28 @@ class AbstractReward: :return: The reward component. :rtype: AbstractReward """ - return cls() + if config["type"] not in cls._registry: + raise ValueError(f"Invalid reward type {config['type']}") + reward_class = cls._registry[config["type"]] + reward_obj = reward_class(config=reward_class.ConfigSchema(**config)) + return reward_obj + + @abstractmethod + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ + return 0.0 -class DummyReward(AbstractReward): - """Dummy reward function component which always returns 0.""" +class DummyReward(AbstractReward, identifier="DUMMY"): + """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -86,41 +107,21 @@ class DummyReward(AbstractReward): """ return 0.0 - @classmethod - def from_config(cls, config: dict) -> "DummyReward": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor. Should be empty. - :type config: dict - :return: The reward component. - :rtype: DummyReward - """ - return cls() - - -class DatabaseFileIntegrity(AbstractReward): +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" - def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: - """Initialise the reward component. + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - :param node_hostname: Hostname of the node which contains the database file. - :type node_hostname: str - :param folder_name: folder which contains the database file. - :type folder_name: str - :param file_name: name of the database file. - :type file_name: str - """ - self.location_in_state = [ - "network", - "nodes", - node_hostname, - "file_system", - "folders", - folder_name, - "files", - file_name, - ] + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + type: str = "DATABASE_FILE_INTEGRITY" + node_hostname: str + folder_name: str + file_name: str def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -132,6 +133,17 @@ class DatabaseFileIntegrity(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "file_system", + "folders", + self.config.folder_name, + "files", + self.config.file_name, + ] + database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: _LOGGER.debug( @@ -148,44 +160,21 @@ class DatabaseFileIntegrity(AbstractReward): else: return 0 - @classmethod - def from_config(cls, config: Dict) -> "DatabaseFileIntegrity": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: DatabaseFileIntegrity - """ - node_hostname = config.get("node_hostname") - folder_name = config.get("folder_name") - file_name = config.get("file_name") - if not (node_hostname and folder_name and file_name): - msg = f"{cls.__name__} could not be initialised with parameters {config}" - _LOGGER.error(msg) - raise ValueError(msg) - - return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) - - -class WebServer404Penalty(AbstractReward): +class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): """Reward function component which penalises the agent when the web server returns a 404 error.""" - def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: - """Initialise the reward component. + config: "WebServer404Penalty.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - :param node_hostname: Hostname of the node which contains the web server service. - :type node_hostname: str - :param service_name: Name of the web server service. - :type service_name: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebServer404Penalty.""" + + type: str = "WEB_SERVER_404_PENALTY" + node_hostname: str + service_name: str + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -197,6 +186,13 @@ class WebServer404Penalty(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "services", + self.config.service_name, + ] web_service_state = access_from_nested_dict(state, self.location_in_state) # if webserver is no longer installed on the node, return 0 @@ -211,54 +207,27 @@ class WebServer404Penalty(AbstractReward): return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average - elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 + elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0 self.reward = 0.0 else: # skip calculating if sticky and no new codes. instead, reuse last step's value pass return self.reward - @classmethod - def from_config(cls, config: Dict) -> "WebServer404Penalty": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: WebServer404Penalty - """ - node_hostname = config.get("node_hostname") - service_name = config.get("service_name") - if not (node_hostname and service_name): - msg = ( - f"{cls.__name__} could not be initialised from config because node_name and service_ref were not " - "found in reward config." - ) - _LOGGER.warning(msg) - raise ValueError(msg) - sticky = config.get("sticky", True) - - return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) - - -class WebpageUnavailablePenalty(AbstractReward): +class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"): """Penalises the agent when the web browser fails to fetch a webpage.""" - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + config: "WebpageUnavailablePenalty.ConfigSchema" + reward: float = 0.0 + location_in_state: List[str] = [""] # Calculate in __init__()? - :param node_hostname: Hostname of the node which has the web browser. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebpageUnavailablePenalty.""" + + type: str = "WEBPAGE_UNAVAILABLE_PENALTY" + node_hostname: str = "" + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -274,6 +243,13 @@ class WebpageUnavailablePenalty(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "applications", + "WebBrowser", + ] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -283,14 +259,14 @@ class WebpageUnavailablePenalty(AbstractReward): request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.config.node_hostname, "application", "WebBrowser", "execute", ] # skip calculating if sticky and no new codes, reusing last step value - if not request_attempted and self.sticky: + if not request_attempted and self.config.sticky: return self.reward if last_action_response.response.status != "success": @@ -298,7 +274,7 @@ class WebpageUnavailablePenalty(AbstractReward): elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", - f"{self._node} was not reported in the simulation state. Returning 0.0", + f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 else: @@ -312,37 +288,19 @@ class WebpageUnavailablePenalty(AbstractReward): return self.reward - @classmethod - def from_config(cls, config: dict) -> AbstractReward: - """ - Build the reward component object from config. - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - - -class GreenAdminDatabaseUnreachablePenalty(AbstractReward): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"): """Penalises the agent when the green db clients fail to connect to the database.""" - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" + reward: float = 0.0 - :param node_hostname: Hostname of the node where the database client sits. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" + + type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY" + node_hostname: str + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -362,7 +320,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.config.node_hostname, "application", "DatabaseClient", "execute", @@ -371,7 +329,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): if request_attempted: # if agent makes request, always recalculate fresh value last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status} self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 - elif not self.sticky: # if no new request and not sticky, set reward to 0 + elif not self.config.sticky: # if no new request and not sticky, set reward to 0 last_action_response.reward_info = {"connection_attempt_status": "n/a"} self.reward = 0.0 else: # if no new request and sticky, reuse reward value from last step @@ -380,47 +338,30 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return self.reward - @classmethod - def from_config(cls, config: Dict) -> AbstractReward: - """ - Build the reward component object from config. - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - - -class SharedReward(AbstractReward): +class SharedReward(AbstractReward, identifier="SHARED_REWARD"): """Adds another agent's reward to the overall reward.""" - def __init__(self, agent_name: Optional[str] = None) -> None: + config: "SharedReward.ConfigSchema" + + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for SharedReward.""" + + type: str = "SHARED_REWARD" + agent_name: str + + def default_callback(agent_name: str) -> Never: """ - Initialise the shared reward. + Default callback to prevent calling this reward until it's properly initialised. - The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - correctly. - - :param agent_name: The name whose reward is an input - :type agent_name: Optional[str] + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. """ - self.agent_name = agent_name - """Agent whose reward to track.""" + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - def default_callback(agent_name: str) -> Never: - """ - Default callback to prevent calling this reward until it's properly initialised. - - SharedReward should not be used until the game layer replaces self.callback with a reference to the - function that retrieves the desired agent's reward. Therefore, we define this default callback that raises - an error. - """ - raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - - self.callback: Callable[[str], float] = default_callback - """Method that retrieves an agent's current reward given the agent's name.""" + callback: Callable[[str], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it. @@ -432,36 +373,25 @@ class SharedReward(AbstractReward): :return: Reward value :rtype: float """ - return self.callback(self.agent_name) - - @classmethod - def from_config(cls, config: Dict) -> "SharedReward": - """ - Build the SharedReward object from config. - - :param config: Configuration dictionary - :type config: Dict - """ - agent_name = config.get("agent_name") - return cls(agent_name=agent_name) + return self.callback(self.config.agent_name) -class ActionPenalty(AbstractReward): +class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"): """Apply a negative reward when taking any action except DONOTHING.""" - def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - """ - Initialise the reward. + config: "ActionPenalty.ConfigSchema" - Reward or penalise agents for doing nothing or taking actions. + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for ActionPenalty. - :param action_penalty: Reward to give agents for taking any action except DONOTHING + :param action_penalty: Reward to give agents for taking any action except do_nothing :type action_penalty: float - :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action + :param do_nothing_penalty: Reward to give agent for taking the do_nothing action :type do_nothing_penalty: float """ - self.action_penalty = action_penalty - self.do_nothing_penalty = do_nothing_penalty + + action_penalty: float = -1.0 + do_nothing_penalty: float = 0.0 def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied. @@ -473,33 +403,16 @@ class ActionPenalty(AbstractReward): :return: Reward value :rtype: float """ - if last_action_response.action == "DONOTHING": + if last_action_response.action == "do_nothing": return self.do_nothing_penalty - else: - return self.action_penalty - @classmethod - def from_config(cls, config: Dict) -> "ActionPenalty": - """Build the ActionPenalty object from config.""" - action_penalty = config.get("action_penalty", -1.0) - do_nothing_penalty = config.get("do_nothing_penalty", 0.0) - return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) + else: + return self.config.action_penalty class RewardFunction: """Manages the reward function for the agent.""" - rew_class_identifiers: Dict[str, Type[AbstractReward]] = { - "DUMMY": DummyReward, - "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, - "WEB_SERVER_404_PENALTY": WebServer404Penalty, - "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, - "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, - "SHARED_REWARD": SharedReward, - "ACTION_PENALTY": ActionPenalty, - } - """List of reward class identifiers.""" - def __init__(self): """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] @@ -534,7 +447,7 @@ class RewardFunction: @classmethod def from_config(cls, config: Dict) -> "RewardFunction": - """Create a reward function from a config dictionary. + """Create a reward function from a config dictionary and its related reward class. :param config: dict of options for the reward manager's constructor :type config: Dict @@ -545,8 +458,11 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] + # XXX: If options key is missing add key then add type key. + if "options" not in rew_component_cfg: + rew_component_cfg["options"] = {} + rew_component_cfg["options"]["type"] = rew_type weight = rew_component_cfg.get("weight", 1.0) - rew_class = cls.rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {})) + rew_instance = AbstractReward.from_config(rew_component_cfg["options"]) new.register_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f2b1de4c..bf480d0e 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -370,7 +370,7 @@ class PrimaiteGame: if service_class is not None: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(service_class) + new_node.software_manager.install(service_class, **service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_class.__name__] # fixing duration for the service @@ -580,7 +580,7 @@ class PrimaiteGame: for comp, weight in agent.reward_function.reward_components: if isinstance(comp, SharedReward): comp: SharedReward - graph[name].add(comp.agent_name) + graph[name].add(comp.config.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index ba70f2b4..858b4bb6 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -19,7 +19,7 @@ "source": [ "from primaite.session.environment import PrimaiteGymEnv\n", "from primaite.config.load import data_manipulation_config_path\n", - "from prettytable import PrettyTable\n" + "from prettytable import PrettyTable" ] }, { @@ -195,7 +195,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -209,7 +209,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb similarity index 99% rename from src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb rename to src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb index 6e6819fa..d2972fa9 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb @@ -1780,10 +1780,11 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] @@ -1804,7 +1805,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -1818,7 +1819,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 0460f771..89620215 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -165,13 +165,13 @@ "\n", "| node_id | node name |\n", "|---------|------------------|\n", - "| 1 | domain_controller|\n", - "| 2 | web_server |\n", - "| 3 | database_server |\n", - "| 4 | backup_server |\n", - "| 5 | security_suite |\n", - "| 6 | client_1 |\n", - "| 7 | client_2 |\n", + "| 0 | domain_controller|\n", + "| 1 | web_server |\n", + "| 2 | database_server |\n", + "| 3 | backup_server |\n", + "| 4 | security_suite |\n", + "| 5 | client_1 |\n", + "| 6 | client_2 |\n", "\n", "Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", "\n", diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index dbe8871c..0fd212f2 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -95,7 +95,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -109,7 +109,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 892736fe..5255b0ad 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -168,7 +168,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -182,7 +182,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index f573f251..7af8b98e 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -166,9 +166,10 @@ "from pathlib import Path\n", "from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n", "from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.file_system.file_system import FileSystem\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "\n", "# no applications exist yet so we will create our own.\n", "class MSPaint(Application, identifier=\"MSPaint\"):\n", @@ -182,7 +183,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=PORT_LOOKUP[\"HTTP\"], protocol = PROTOCOL_LOOKUP[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { @@ -249,7 +250,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -263,7 +264,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 2d5b4772..58c07fe5 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -532,12 +532,12 @@ }, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol["ICMP"],\n", + " protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" @@ -650,7 +650,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 94c45428..ebd17638 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import ABC, abstractmethod from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Literal, Type +from typing import Any, ClassVar, Dict, Literal, Optional, Type from pydantic import BaseModel, model_validator @@ -49,7 +49,7 @@ class NetworkNodeAdder(BaseModel): _registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {} - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None: """ Register a network node adder class. @@ -58,6 +58,8 @@ class NetworkNodeAdder(BaseModel): :raises ValueError: When attempting to register a name that is already reserved. """ super().__init_subclass__(**kwargs) + if identifier is None: + return if identifier in cls._registry: raise ValueError(f"Duplicate node adder {identifier}") cls._registry[identifier] = cls diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 8324715f..0bbc0768 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1545,7 +1545,7 @@ class Node(SimComponent): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" - def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: """ Register a node type. @@ -1553,10 +1553,10 @@ class Node(SimComponent): :type identifier: str :raises ValueError: When attempting to register an node with a name that is already allocated. """ - if identifier == "default": + super().__init_subclass__(**kwargs) + if identifier is None: return identifier = identifier.lower() - super().__init_subclass__(**kwargs) if identifier in cls._registry: raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.") cls._registry[identifier] = cls diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 1752c09a..250a4dd9 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -44,7 +44,7 @@ class Application(IOSoftware): _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: """ Register an application type. @@ -52,9 +52,9 @@ class Application(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an application with a name that is already allocated. """ - if identifier == "default": - return super().__init_subclass__(**kwargs) + if identifier is None: + return if identifier in cls._registry: raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") cls._registry[identifier] = cls diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 840214f3..f3ccc531 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -308,6 +308,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ if not self._can_perform_action(): return None + if self.server_ip_address is None: + self.sys_log.warning(f"{self.name}: Database server IP address not provided.") + return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 48f1f383..d6617efa 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -16,7 +16,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: - from primaite.simulator.network.hardware.base import NetworkInterface + from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 3dc080b4..329aefef 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -52,7 +52,7 @@ class Service(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: """ Register a hostnode type. @@ -60,11 +60,11 @@ class Service(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == "default": + super().__init_subclass__(**kwargs) + if identifier is None: return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() - super().__init_subclass__(**kwargs) if identifier in cls._registry: raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.") cls._registry[identifier] = cls diff --git a/src/primaite/utils/validation/ipv4_address.py b/src/primaite/utils/validation/ipv4_address.py index c385ed1e..b2b8b72e 100644 --- a/src/primaite/utils/validation/ipv4_address.py +++ b/src/primaite/utils/validation/ipv4_address.py @@ -31,7 +31,7 @@ def ipv4_validator(v: Any) -> IPv4Address: IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)] """ -IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator.. +IPv4Address with pre-validation and auto-conversion from str using ipv4_validator.. This type is essentially an IPv4Address from the standard library's ipaddress module, but with added validation logic. If you use this custom type, the ipv4_validator function diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 42400253..8aa97a6b 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -205,8 +205,6 @@ simulation: port_scan_p_of_success: 0.8 services: - type: DNSClient - options: - dns_server: 192.168.1.10 - type: DNSServer options: domain_mapping: diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml index d604192e..a2d9bb55 100644 --- a/tests/assets/configs/data_manipulation.yaml +++ b/tests/assets/configs/data_manipulation.yaml @@ -33,7 +33,7 @@ agents: observation_space: null action_space: action_list: - - type: DONOTHING + - type: do_nothing - type: NODE_APPLICATION_EXECUTE options: nodes: @@ -47,7 +47,7 @@ agents: max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do_nothing options: {} 1: action: NODE_APPLICATION_EXECUTE @@ -82,7 +82,7 @@ agents: observation_space: null action_space: action_list: - - type: DONOTHING + - type: do_nothing - type: NODE_APPLICATION_EXECUTE options: nodes: @@ -96,7 +96,7 @@ agents: max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do_nothing options: {} 1: action: NODE_APPLICATION_EXECUTE @@ -132,7 +132,7 @@ agents: action_space: action_list: - - type: DONOTHING + - type: do_nothing - type: NODE_APPLICATION_EXECUTE options: nodes: @@ -235,7 +235,7 @@ agents: action_space: action_list: - - type: DONOTHING + - type: do_nothing - type: NODE_SERVICE_SCAN - type: NODE_SERVICE_STOP - type: NODE_SERVICE_START @@ -265,7 +265,7 @@ agents: action_map: 0: - action: DONOTHING + action: do_nothing options: {} # scan webapp service 1: diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index 2d42e4c5..d88942a8 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -96,155 +96,158 @@ agents: action_space: action_list: - - type: DONOTHING - - type: FIREWALL_ACL_ADDRULE - - type: FIREWALL_ACL_REMOVERULE - - type: NETWORK_PORT_DISABLE - - type: NETWORK_PORT_ENABLE + - type: do_nothing + - type: firewall_acl_add_rule + - type: firewall_acl_remove_rule + - type: network_port_disable + - type: network_port_enable action_map: 0: - action: DONOTHING + action: do_nothing options: {} 1: - action: FIREWALL_ACL_ADDRULE + action: firewall_acl_add_rule options: + type: firewall_acl_add_rule target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: inbound position: 1 - permission: 1 - source_ip_id: 2 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: PERMIT + src_ip: 192.168.0.10 + dst_ip: 0.0.0.0 + src_port: 80 + dst_port: HTTP + protocol_name: TCP + src_wildcard: 0 + dst_wildcard: 0 2: - action: FIREWALL_ACL_REMOVERULE + action: firewall_acl_remove_rule options: target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: inbound position: 1 3: - action: FIREWALL_ACL_ADDRULE + action: firewall_acl_add_rule options: target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: outbound position: 1 - permission: 2 - source_ip_id: 2 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 2 - dest_port_id: 3 - protocol_id: 2 + permission: DENY + src_ip: 192.168.0.10 # client 1 + dest_ip: ALL + src_port: ARP + dst_port: DNS + protocol_name: ICMP source_wildcard_id: 0 dest_wildcard_id: 0 4: - action: FIREWALL_ACL_REMOVERULE + action: firewall_acl_remove_rule options: target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: outbound position: 1 5: - action: FIREWALL_ACL_ADDRULE + action: firewall_acl_add_rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: inbound position: 1 - permission: 2 - source_ip_id: 3 # dmz_server - dest_ip_id: 2 # client_1 - source_port_id: 4 - dest_port_id: 4 - protocol_id: 4 + permission: DENY + src_ip: 192.168.10.10 # dmz_server + dest_ip: 192.168.0.10 # client_1 + src_port: HTTP + dst_port: HTTP + protocol_name: UDP source_wildcard_id: 0 dest_wildcard_id: 0 6: - action: FIREWALL_ACL_REMOVERULE + action: firewall_acl_remove_rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: inbound position: 1 7: - action: FIREWALL_ACL_ADDRULE + action: firewall_acl_add_rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: outbound position: 2 - permission: 2 - source_ip_id: 3 # dmz_server - dest_ip_id: 2 # client_1 - source_port_id: 4 - dest_port_id: 4 - protocol_id: 3 + permission: DENY + src_ip: 192.168.10.10 # dmz_server + dest_ip: 192.168.0.10 # client_1 + src_port: HTTP + dst_port: HTTP + protocol_name: TCP source_wildcard_id: 0 dest_wildcard_id: 0 8: - action: FIREWALL_ACL_REMOVERULE + action: firewall_acl_remove_rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: outbound position: 2 9: - action: FIREWALL_ACL_ADDRULE + action: firewall_acl_add_rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: inbound position: 10 - permission: 2 - source_ip_id: 4 # external_computer - dest_ip_id: 3 # dmz - source_port_id: 5 - dest_port_id: 5 - protocol_id: 2 + permission: DENY + src_ip: 192.168.20.10 # external_computer + dest_ip: 192.168.10.10 # dmz + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + protocol_name: ICMP source_wildcard_id: 0 dest_wildcard_id: 0 10: - action: FIREWALL_ACL_REMOVERULE + action: firewall_acl_remove_rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: inbound position: 10 11: - action: FIREWALL_ACL_ADDRULE + action: firewall_acl_add_rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: outbound position: 1 - permission: 2 - source_ip_id: 4 # external_computer - dest_ip_id: 2 # client_1 - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 + permission: DENY + src_ip: 192.168.20.10 # external_computer + dest_ip: 192.168.0.10 # client_1 + src_port: NONE + dst_port: NONE + protocol_name: none source_wildcard_id: 0 dest_wildcard_id: 0 12: - action: FIREWALL_ACL_REMOVERULE + action: firewall_acl_remove_rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: outbound position: 1 13: - action: NETWORK_PORT_DISABLE + action: network_port_disable options: + type: network_port_disable target_nodename: firewall port_id: 3 14: - action: NETWORK_PORT_ENABLE + action: network_port_enable options: + type: network_port_enable target_nodename: firewall port_id: 3 options: diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml index 0252ac32..704616f6 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/fix_duration_one_item.yaml @@ -201,8 +201,6 @@ simulation: port_scan_p_of_success: 0.8 services: - type: DNSClient - options: - dns_server: 192.168.1.10 - type: DNSServer options: domain_mapping: @@ -233,8 +231,6 @@ simulation: server_password: arcd services: - type: DNSClient - options: - dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 diff --git a/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml b/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml index c5508f13..ec50ecdf 100644 --- a/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml +++ b/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml @@ -34,15 +34,16 @@ agents: max_services_per_node: 1 max_applications_per_node: 1 action_list: - - type: NODE_NMAP_NETWORK_SERVICE_RECON + - type: node_network_service_recon action_map: 0: - action: NODE_NMAP_NETWORK_SERVICE_RECON + action: node_network_service_recon options: source_node: client_1 target_ip_address: 192.168.10.0/24 target_port: 80 target_protocol: tcp + show: false reward_function: reward_components: diff --git a/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml b/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml index 33ba3d19..eb7b6752 100644 --- a/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml +++ b/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml @@ -34,13 +34,14 @@ agents: max_services_per_node: 1 max_applications_per_node: 1 action_list: - - type: NODE_NMAP_PING_SCAN + - type: node_nmap_ping_scan action_map: 0: - action: NODE_NMAP_PING_SCAN + action: node_nmap_ping_scan options: - source_node: client_1 + node_name: client_1 target_ip_address: 192.168.1.0/24 + show: False reward_function: reward_components: diff --git a/tests/assets/configs/nmap_port_scan_red_agent_config.yaml b/tests/assets/configs/nmap_port_scan_red_agent_config.yaml index 8ed715c1..15e2cb6a 100644 --- a/tests/assets/configs/nmap_port_scan_red_agent_config.yaml +++ b/tests/assets/configs/nmap_port_scan_red_agent_config.yaml @@ -34,19 +34,21 @@ agents: max_services_per_node: 1 max_applications_per_node: 1 action_list: - - type: NODE_NMAP_PORT_SCAN + - type: node_nmap_port_scan action_map: 0: - action: NODE_NMAP_PORT_SCAN + action: node_nmap_port_scan options: source_node: client_1 target_ip_address: 192.168.10.0/24 + target_protocol: tcp target_port: - 21 - 53 - 80 - 123 - 219 + show: false reward_function: reward_components: diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index 98260fe3..d57b88dd 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -210,7 +210,6 @@ simulation: services: - type: DNSClient options: - dns_server: 192.168.1.10 fix_duration: 3 - type: DNSServer options: @@ -251,8 +250,6 @@ simulation: server_password: arcd services: - type: DNSClient - options: - dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 diff --git a/tests/conftest.py b/tests/conftest.py index f4630c9a..c8c5e694 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -415,85 +415,58 @@ def game_and_agent(): install_stuff_to_sim(sim) actions = [ - {"type": "DONOTHING"}, - {"type": "NODE_SERVICE_SCAN"}, - {"type": "NODE_SERVICE_STOP"}, - {"type": "NODE_SERVICE_START"}, - {"type": "NODE_SERVICE_PAUSE"}, - {"type": "NODE_SERVICE_RESUME"}, - {"type": "NODE_SERVICE_RESTART"}, - {"type": "NODE_SERVICE_DISABLE"}, - {"type": "NODE_SERVICE_ENABLE"}, - {"type": "NODE_SERVICE_FIX"}, - {"type": "NODE_APPLICATION_EXECUTE"}, - {"type": "NODE_APPLICATION_SCAN"}, - {"type": "NODE_APPLICATION_CLOSE"}, - {"type": "NODE_APPLICATION_FIX"}, - {"type": "NODE_APPLICATION_INSTALL"}, - {"type": "NODE_APPLICATION_REMOVE"}, - {"type": "NODE_FILE_CREATE"}, - {"type": "NODE_FILE_SCAN"}, - {"type": "NODE_FILE_CHECKHASH"}, - {"type": "NODE_FILE_DELETE"}, - {"type": "NODE_FILE_REPAIR"}, - {"type": "NODE_FILE_RESTORE"}, - {"type": "NODE_FILE_CORRUPT"}, - {"type": "NODE_FILE_ACCESS"}, - {"type": "NODE_FOLDER_CREATE"}, - {"type": "NODE_FOLDER_SCAN"}, - {"type": "NODE_FOLDER_CHECKHASH"}, - {"type": "NODE_FOLDER_REPAIR"}, - {"type": "NODE_FOLDER_RESTORE"}, - {"type": "NODE_OS_SCAN"}, - {"type": "NODE_SHUTDOWN"}, - {"type": "NODE_STARTUP"}, - {"type": "NODE_RESET"}, - {"type": "ROUTER_ACL_ADDRULE"}, - {"type": "ROUTER_ACL_REMOVERULE"}, - {"type": "HOST_NIC_ENABLE"}, - {"type": "HOST_NIC_DISABLE"}, - {"type": "NETWORK_PORT_ENABLE"}, - {"type": "NETWORK_PORT_DISABLE"}, - {"type": "CONFIGURE_C2_BEACON"}, - {"type": "C2_SERVER_RANSOMWARE_LAUNCH"}, - {"type": "C2_SERVER_RANSOMWARE_CONFIGURE"}, - {"type": "C2_SERVER_TERMINAL_COMMAND"}, - {"type": "C2_SERVER_DATA_EXFILTRATE"}, - {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, - {"type": "SSH_TO_REMOTE"}, - {"type": "SESSIONS_REMOTE_LOGOFF"}, - {"type": "NODE_SEND_REMOTE_COMMAND"}, + {"type": "do_nothing"}, + {"type": "node_service_scan"}, + {"type": "node_service_stop"}, + {"type": "node_service_start"}, + {"type": "node_service_pause"}, + {"type": "node_service_resume"}, + {"type": "node_service_restart"}, + {"type": "node_service_disable"}, + {"type": "node_service_enable"}, + {"type": "node_service_fix"}, + {"type": "node_application_execute"}, + {"type": "node_application_scan"}, + {"type": "node_application_close"}, + {"type": "node_application_fix"}, + {"type": "node_application_install"}, + {"type": "node_application_remove"}, + {"type": "node_file_create"}, + {"type": "node_file_scan"}, + {"type": "node_file_checkhash"}, + {"type": "node_file_delete"}, + {"type": "node_file_repair"}, + {"type": "node_file_restore"}, + {"type": "node_file_corrupt"}, + {"type": "node_file_access"}, + {"type": "node_folder_create"}, + {"type": "node_folder_scan"}, + {"type": "node_folder_checkhash"}, + {"type": "node_folder_repair"}, + {"type": "node_folder_restore"}, + {"type": "node_os_scan"}, + {"type": "node_shutdown"}, + {"type": "node_startup"}, + {"type": "node_reset"}, + {"type": "router_acl_add_rule"}, + {"type": "router_acl_remove_rule"}, + {"type": "host_nic_enable"}, + {"type": "host_nic_disable"}, + {"type": "network_port_enable"}, + {"type": "network_port_disable"}, + {"type": "configure_c2_beacon"}, + {"type": "c2_server_ransomware_launch"}, + {"type": "c2_server_ransomware_configure"}, + {"type": "c2_server_terminal_command"}, + {"type": "c2_server_data_exfiltrate"}, + {"type": "node_account_change_password"}, + {"type": "node_session_remote_login"}, + {"type": "node_session_remote_logoff"}, + {"type": "node_send_remote_command"}, ] action_space = ActionManager( actions=actions, # ALL POSSIBLE ACTIONS - nodes=[ - { - "node_name": "client_1", - "applications": [ - {"application_name": "WebBrowser"}, - {"application_name": "DoSBot"}, - {"application_name": "C2Server"}, - ], - "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], - }, - { - "node_name": "server_1", - "services": [{"service_name": "DNSServer"}], - "applications": [{"application_name": "C2Beacon"}], - }, - {"node_name": "server_2", "services": [{"service_name": "WebServer"}]}, - {"node_name": "router"}, - ], - max_folders_per_node=2, - max_files_per_folder=2, - max_services_per_node=2, - max_applications_per_node=3, - max_nics_per_node=2, - max_acl_rules=10, - protocols=["TCP", "UDP", "ICMP"], - ports=["HTTP", "DNS", "ARP"], - ip_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], act_map={}, ) observation_space = ObservationManager(NestedObservation(components={})) diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py index 5515d900..8e73f929 100644 --- a/tests/integration_tests/extensions/test_extendable_config.py +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -5,6 +5,7 @@ from primaite.config.load import get_extended_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer +from tests import TEST_ASSETS_ROOT from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch @@ -13,11 +14,12 @@ from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch from tests.integration_tests.extensions.nodes.super_computer import SuperComputer from tests.integration_tests.extensions.services.extended_service import ExtendedService +CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml" + def test_extended_example_config(): """Test that the example config can be parsed properly.""" - config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml") - game = load_config(config_path) + game = load_config(CONFIG_PATH) network: Network = game.simulation.network assert len(network.nodes) == 10 # 10 nodes in example network diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 7930304e..36fee9a0 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -134,7 +134,7 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA # Stepping a few timesteps to allow for the RansowmareScript to finish installing. - action = ("DONOTHING", {}) + action = ("do_nothing", {}) agent.store_action(action) game.step() game.step() diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 4e72e60d..338bd049 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address import pytest from pydantic import ValidationError -from primaite.game.agent.actions import ( +from primaite.game.agent.actions.software import ( ConfigureDatabaseClientAction, ConfigureDoSBotAction, ConfigureRansomwareScriptAction, @@ -35,10 +35,10 @@ class TestConfigureDatabaseAction: db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] action = ( - "CONFIGURE_DATABASE_CLIENT", + "configure_database_client", { - "node_id": 0, - "config": { + "node_name": "client_1", + "model_config": { "server_ip_address": "192.168.1.99", "server_password": "admin123", }, @@ -53,7 +53,7 @@ class TestConfigureDatabaseAction: def test_configure_ip(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + agent.action_manager.actions["configure_database_client"] = ConfigureDatabaseClientAction(agent.action_manager) # make sure there is a database client on this node client_1 = game.simulation.network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py index d75567e9..8fbbbd70 100644 --- a/tests/integration_tests/game_layer/actions/test_node_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -36,7 +36,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN for i in range(client_1.shut_down_duration + 1): - action = ("DONOTHING", {"node_id": 0}) + action = ("do_nothing", {"node_id": 0}) agent.store_action(action) game.step() @@ -50,7 +50,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert client_1.operating_state == NodeOperatingState.BOOTING for i in range(client_1.start_up_duration + 1): - action = ("DONOTHING", {"node_id": 0}) + action = ("do_nothing", {"node_id": 0}) agent.store_action(action) game.step() @@ -80,7 +80,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: client_1.power_off() for i in range(client_1.shut_down_duration + 1): - action = ("DONOTHING", {"node_id": 0}) + action = ("do_nothing", {"node_id": 0}) agent.store_action(action) game.step() diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index 45ad2708..464f95db 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -24,12 +24,12 @@ def test_rng_seed_set(create_env): env.reset(seed=3) for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"] env.reset(seed=3) for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"] assert a == b @@ -40,11 +40,11 @@ def test_rng_seed_unset(create_env): env.reset() for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"] env.reset() for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"] assert a != b diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py index 3d26b73d..22c00aa4 100644 --- a/tests/integration_tests/game_layer/test_action_mask.py +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -91,7 +91,7 @@ def test_mask_contents_correct(): assert mask[action_num] node_obj.operating_state = NodeOperatingState.ON - if act_type == "DONOTHING": + if act_type == "do_nothing": assert mask[action_num] if act_type == "NODE_SERVICE_DISABLE": diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index beb7b6a8..ff86dbf0 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -32,10 +32,10 @@ FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network. def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): - """Test that the DoNothingAction can form a request and that it is accepted by the simulation.""" + """Test that the do_nothingAction can form a request and that it is accepted by the simulation.""" game, agent = game_and_agent - action = ("DONOTHING", {}) + action = ("do_nothing", {}) agent.store_action(action) game.step() @@ -56,7 +56,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy assert svc.health_state_visible == SoftwareHealthState.UNUSED # 2: Scan and check that the visible state is now correct - action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + action = ("node_service_scan", {"type": "node_service_scan", "node_name": "server_1", "service_name": "DNSServer"}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.GOOD @@ -67,7 +67,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy assert svc.health_state_visible == SoftwareHealthState.GOOD # 4: Scan and check that the visible state is now correct - action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + action = ("node_service_scan", {"type": "node_service_scan", "node_name": "server_1", "service_name": "DNSServer"}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.COMPROMISED @@ -88,7 +88,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA svc.health_state_actual = SoftwareHealthState.COMPROMISED # 2: Apply a patch action - action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0}) + action = ("node_service_fix", {"type": "node_service_fix", "node_name": "server_1", "service_name": "DNSServer"}) agent.store_action(action) game.step() @@ -96,7 +96,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA assert svc.health_state_actual == SoftwareHealthState.FIXING # 4: perform a few do-nothing steps and check that the service is now in the good state - action = ("DONOTHING", {}) + action = ("do_nothing", {}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.GOOD @@ -121,18 +121,19 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox # 2: Add a rule to block client 1 from reaching server 2 on router action = ( - "ROUTER_ACL_ADDRULE", + "router_acl_add_rule", { + "type": "router_acl_add_rule", "target_router": "router", - "position": 4, # 4th rule - "permission": 2, # DENY - "source_ip_id": 3, # 10.0.1.2 (client_1) - "dest_ip_id": 6, # 10.0.2.3 (server_2) - "dest_port_id": 1, # ALL - "source_port_id": 1, # ALL - "protocol_id": 1, # ALL - "source_wildcard_id": 0, - "dest_wildcard_id": 0, + "position": 4, + "permission": "DENY", + "src_ip": "10.0.1.2", + "src_wildcard": 0, + "src_port": "HTTP", + "dst_ip": "10.0.2.3", + "dst_wildcard": 0, + "dst_port": "HTTP", + "protocol_name": "udp", }, ) agent.store_action(action) @@ -148,24 +149,27 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox # 4: Add a rule to block server_1 from reaching server_2 on router (this should not affect comms as they are on same subnet) action = ( - "ROUTER_ACL_ADDRULE", + "router_acl_add_rule", { + "type": "router_acl_add_rule", "target_router": "router", "position": 5, # 5th rule - "permission": 2, # DENY - "source_ip_id": 5, # 10.0.2.2 (server_1) - "dest_ip_id": 6, # 10.0.2.3 (server_2) - "dest_port_id": 1, # ALL - "source_port_id": 1, # ALL - "protocol_id": 1, # ALL - "source_wildcard_id": 0, - "dest_wildcard_id": 0, + "permission": "DENY", # DENY + "src_ip": "10.0.2.2", # 10.0.2.2 (server_1) + "src_wildcard": 0, + "source_port": "ALL", # ALL + "dst_ip": "10.0.2.3", # 10.0.2.3 (server_2) + "dst_wildcard": 0, + "dst_port": "ALL", # ALL + "protocol_name": "ALL", # ALL }, ) agent.store_action(action) + print(agent.most_recent_action) game.step() - + print(agent.most_recent_action) # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 + print(router.acl.show()) assert router.acl.num_rules == 6 assert server_1.ping("10.0.2.3") # Can ping server_2 @@ -186,8 +190,9 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P # 2: Remove rule that allows HTTP traffic across the network action = ( - "ROUTER_ACL_REMOVERULE", + "router_acl_remove_rule", { + "type": "router_acl_remove_rule", "target_router": "router", "position": 3, # 4th rule }, @@ -219,10 +224,11 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA # 2: Disable the NIC on client_1 action = ( - "HOST_NIC_DISABLE", + "host_nic_disable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "type": "host_nic_disable", + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -250,10 +256,11 @@ def test_host_nic_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAg # 2: Use action to enable nic action = ( - "HOST_NIC_ENABLE", + "host_nic_enable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "type": "host_nic_enable", + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -277,11 +284,12 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge # 2: perform a scan and make sure nothing has changed action = ( - "NODE_FILE_SCAN", + "node_file_scan", { - "node_id": 0, # client_1, - "folder_id": 0, # downloads, - "file_id": 0, # cat.png + "type": "node_file_scan", + "node_name": "client_1", # client_1, + "folder_name": "downloads", # downloads, + "file_name": "cat.png", # cat.png }, ) agent.store_action(action) @@ -314,11 +322,12 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA # 2: delete the file action = ( - "NODE_FILE_DELETE", + "node_file_delete", { - "node_id": 0, # client_1 - "folder_id": 0, # downloads - "file_id": 0, # cat.png + "type": "node_file_delete", + "node_name": "client_1", # client_1 + "folder_name": "downloads", # downloads + "file_name": "cat.png", # cat.png }, ) agent.store_action(action) @@ -334,14 +343,16 @@ def test_node_file_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): """Test that a file is created.""" game, agent = game_and_agent - client_1 = game.simulation.network.get_node_by_hostname("client_1") # + client_1 = game.simulation.network.get_node_by_hostname("client_1") action = ( - "NODE_FILE_CREATE", + "node_file_create", { - "node_id": 0, + "type": "node_file_create", + "node_name": "client_1", "folder_name": "test", "file_name": "file.txt", + "force": "False", }, ) agent.store_action(action) @@ -357,9 +368,10 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): client_1 = game.simulation.network.get_node_by_hostname("client_1") # action = ( - "NODE_FILE_CREATE", + "node_file_create", { - "node_id": 0, + "type": "node_file_create", + "node_name": "client_1", "folder_name": "test", "file_name": "file.txt", }, @@ -370,9 +382,10 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): assert client_1.file_system.get_file(folder_name="test", file_name="file.txt").num_access == 0 action = ( - "NODE_FILE_ACCESS", + "node_file_access", { - "node_id": 0, + "type": "node_file_access", + "node_name": "client_1", "folder_name": "test", "file_name": "file.txt", }, @@ -390,9 +403,10 @@ def test_node_folder_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): client_1 = game.simulation.network.get_node_by_hostname("client_1") # action = ( - "NODE_FOLDER_CREATE", + "node_folder_create", { - "node_id": 0, + "type": "node_folder_create", + "node_name": "client_1", "folder_name": "test", }, ) @@ -418,8 +432,9 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG # 2: Disable the NIC on client_1 action = ( - "NETWORK_PORT_DISABLE", + "network_port_disable", { + "type": "network_port_disable", "target_nodename": "router", # router "port_id": 1, # port 1 }, @@ -450,8 +465,9 @@ def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGa # 2: Use action to enable port action = ( - "NETWORK_PORT_ENABLE", + "network_port_enable", { + "type": "network_port_enable", "target_nodename": "router", # router "port_id": 1, # port 1 }, @@ -480,7 +496,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P assert browser.health_state_visible == SoftwareHealthState.UNUSED # 2: Scan and check that the visible state is now correct - action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + action = ( + "node_application_scan", + {"type": "node_application_scan", "node_name": "client_1", "application_name": "WebBrowser"}, + ) agent.store_action(action) game.step() assert browser.health_state_actual == SoftwareHealthState.GOOD @@ -491,7 +510,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P assert browser.health_state_visible == SoftwareHealthState.GOOD # 4: Scan and check that the visible state is now correct - action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + action = ( + "node_application_scan", + {"type": "node_application_scan", "node_name": "client_1", "application_name": "WebBrowser"}, + ) agent.store_action(action) game.step() assert browser.health_state_actual == SoftwareHealthState.COMPROMISED @@ -512,7 +534,10 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr browser.health_state_actual = SoftwareHealthState.COMPROMISED # 2: Apply a fix action - action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0}) + action = ( + "node_application_fix", + {"type": "node_application_fix", "node_name": "client_1", "application_name": "WebBrowser"}, + ) agent.store_action(action) game.step() @@ -520,7 +545,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr assert browser.health_state_actual == SoftwareHealthState.FIXING # 4: perform a few do-nothing steps and check that the application is now in the good state - action = ("DONOTHING", {}) + action = ("do_nothing", {}) agent.store_action(action) game.step() assert browser.health_state_actual == SoftwareHealthState.GOOD @@ -538,7 +563,10 @@ def test_node_application_close_integration(game_and_agent: Tuple[PrimaiteGame, assert browser.operating_state == ApplicationOperatingState.RUNNING # 2: Apply a close action - action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0}) + action = ( + "node_application_close", + {"type": "node_application_close", "node_name": "client_1", "application_name": "WebBrowser"}, + ) agent.store_action(action) game.step() @@ -549,7 +577,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl """Test that the NodeApplicationInstallAction and NodeApplicationRemoveAction can form a request and that it is accepted by the simulation. - When you initiate a install action, the Application will be installed and configured on the node. + When you initiate an install action, the Application will be installed and configured on the node. The remove action will uninstall the application from the node.""" game, agent = game_and_agent @@ -557,13 +585,19 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl assert client_1.software_manager.software.get("DoSBot") is None - action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"}) + action = ( + "node_application_install", + {"type": "node_application_install", "node_name": "client_1", "application_name": "DoSBot"}, + ) agent.store_action(action) game.step() assert client_1.software_manager.software.get("DoSBot") is not None - action = ("NODE_APPLICATION_REMOVE", {"node_id": 0, "application_name": "DoSBot"}) + action = ( + "node_application_remove", + {"type": "node_application_remove", "node_name": "client_1", "application_name": "DoSBot"}, + ) agent.store_action(action) game.step() @@ -656,9 +690,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_outbound_acl.acl[1].action.name == "DENY" assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.external_outbound_acl.acl[1].dst_port is None - assert firewall.external_outbound_acl.acl[1].src_port is None - assert firewall.external_outbound_acl.acl[1].protocol is None + assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"] + assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"] + assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"] env.step(12) # Remove ACL rule from External Outbound assert firewall.external_outbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index dc7ed132..1648d685 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -18,12 +18,14 @@ from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent -def test_WebpageUnavailablePenalty(game_and_agent): +def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for failing to fetch a website.""" # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent - comp = WebpageUnavailablePenalty(node_hostname="client_1") + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = WebpageUnavailablePenalty(config=schema) + client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() @@ -31,7 +33,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent.reward_function.register_component(comp, 0.7) # Check that before trying to fetch the webpage, the reward is 0.0 - agent.store_action(("DONOTHING", {})) + agent.store_action(("do_nothing", {})) game.step() assert agent.reward_function.current_reward == 0.0 @@ -53,7 +55,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): assert agent.reward_function.current_reward == -0.7 -def test_uc2_rewards(game_and_agent): +def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" game, agent = game_and_agent agent: ControlledAgent @@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent): ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) - comp = GreenAdminDatabaseUnreachablePenalty("client_1") + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = GreenAdminDatabaseUnreachablePenalty(config=schema) request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request) @@ -139,17 +142,19 @@ def test_action_penalty_loads_from_config(): act_penalty_obj = comp[0] if act_penalty_obj is None: pytest.fail("Action penalty reward component was not added to the agent from config.") - assert act_penalty_obj.action_penalty == -0.75 - assert act_penalty_obj.do_nothing_penalty == 0.125 + assert act_penalty_obj.config.action_penalty == -0.75 + assert act_penalty_obj.config.do_nothing_penalty == 0.125 def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + # Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + Penalty = ActionPenalty(config=schema) - # Assert that penalty is applied if action isn't DONOTHING + # Assert that penalty is applied if action isn't do_nothing reward_value = Penalty.calculate( state={}, last_action_response=AgentHistoryItem( @@ -163,12 +168,12 @@ def test_action_penalty(): assert reward_value == -0.75 - # Assert that no penalty applied for a DONOTHING action + # Assert that no penalty applied for a do_nothing action reward_value = Penalty.calculate( state={}, last_action_response=AgentHistoryItem( timestep=0, - action="DONOTHING", + action="do_nothing", parameters={}, request=["do_nothing"], response=RequestResponse.from_bool(True), @@ -178,15 +183,16 @@ def test_action_penalty(): assert reward_value == 0.125 -def test_action_penalty_e2e(game_and_agent): +def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for doing actions to fetch a website.""" game, agent = game_and_agent agent: ControlledAgent - comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + comp = ActionPenalty(config=schema) agent.reward_function.register_component(comp, 1.0) - action = ("DONOTHING", {}) + action = ("do_nothing", {}) agent.store_action(action) game.step() assert agent.reward_function.current_reward == 0.125 diff --git a/tests/unit_tests/_primaite/_game/_agent/test_actions.py b/tests/unit_tests/_primaite/_game/_agent/test_actions.py index 63ea1b07..79cf7e4b 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_actions.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_actions.py @@ -3,9 +3,11 @@ from unittest.mock import Mock import pytest -from primaite.game.agent.actions import ( +from primaite.game.agent.actions import ( # DoNothingAction,; NodeServiceDisableAction,; NodeServiceEnableAction,; NodeServicePauseAction,; NodeServiceRestartAction,; NodeServiceResumeAction,; NodeServiceScanAction,; NodeServiceStartAction,; NodeServiceStopAction, ActionManager, - DoNothingAction, +) +from primaite.game.agent.actions.manager import DoNothingAction +from primaite.game.agent.actions.service import ( NodeServiceDisableAction, NodeServiceEnableAction, NodeServicePauseAction, @@ -18,7 +20,7 @@ from primaite.game.agent.actions import ( def test_do_nothing_action_form_request(): - """Test that the DoNothingAction can form a request and that it is correct.""" + """Test that the do_nothingAction can form a request and that it is correct.""" manager = Mock() action = DoNothingAction(manager=manager) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index 7035e98f..7824e71e 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -28,9 +28,9 @@ def test_probabilistic_agent(): action_space_cfg = { "action_list": [ - {"type": "DONOTHING"}, - {"type": "NODE_APPLICATION_EXECUTE"}, - {"type": "NODE_FILE_DELETE"}, + {"type": "do_nothing"}, + {"type": "node_application_execute"}, + {"type": "node_file_delete"}, ], "nodes": [ { @@ -48,9 +48,9 @@ def test_probabilistic_agent(): "protocols": ["TCP", "UDP", "ICMP"], "ports": ["HTTP", "DNS", "ARP"], "act_map": { - 0: {"action": "DONOTHING", "options": {}}, - 1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}}, - 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, + 0: {"action": "do_nothing", "options": {}}, + 1: {"action": "node_application_execute", "options": {"node_id": 0, "application_id": 0}}, + 2: {"action": "node_file_delete", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, }, "options": {}, } @@ -80,11 +80,11 @@ def test_probabilistic_agent(): node_file_delete_count = 0 for _ in range(N_TRIALS): a = pa.get_action(0) - if a == ("DONOTHING", {}): + if a == ("do_nothing", {}): do_nothing_count += 1 - elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): + elif a == ("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}): node_application_execute_count += 1 - elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}): + elif a == ("node_file_delete", {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}): node_file_delete_count += 1 else: raise AssertionError("Probabilistic agent produced an unexpected action.") diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 0e4bf1bb..91d5c607 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=False) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=False, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=True) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=True, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -67,10 +77,11 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=False) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do_nothing", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} @@ -93,7 +104,7 @@ class TestWebpageUnavailabilitySticky: # THE IMPORTANT BIT # agent did nothing, because reward is not sticky, it goes back to 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} @@ -127,10 +138,11 @@ class TestWebpageUnavailabilitySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=True) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} @@ -153,7 +165,7 @@ class TestWebpageUnavailabilitySticky: # THE IMPORTANT BIT # agent did nothing, because reward is sticky, it stays at 1.0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( @@ -188,10 +200,14 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=False, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem( @@ -212,9 +228,8 @@ class TestGreenAdminDatabaseUnreachableSticky: # THE IMPORTANT BIT # agent did nothing, because reward is not sticky, it goes back to 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) - browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response @@ -244,10 +259,14 @@ class TestGreenAdminDatabaseUnreachableSticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=True, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem( @@ -268,7 +287,7 @@ class TestGreenAdminDatabaseUnreachableSticky: # THE IMPORTANT BIT # agent did nothing, because reward is not sticky, it goes back to 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "DO_NOTHING", {}, ["do_nothing"] response = RequestResponse(status="success", data={}) state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem(