diff --git a/CHANGELOG.md b/CHANGELOG.md index fb7650de..ccaa411a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,8 @@ SessionManager. - `DatabaseClient` and `DatabaseService` created to allow emulation of database actions - Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup - Red Agent Services: - - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database) + - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database). The attack runs in stages with a random, configurable probability of succeeding. + - `DataManipulationAgent` runs the Data Manipulator Bot according to a configured start step, frequency and variance. - DNS Services: `DNSClient` and `DNSServer` - FTP Services: `FTPClient` and `FTPServer` - HTTP Services: `WebBrowser` to simulate a web client and `WebServer` diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst index 53484cac..5180974f 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -18,19 +18,28 @@ The bot is intended to simulate a malicious actor carrying out attacks like: - Modifying data on a database server by abusing an application's trusted database connectivity. +The bot performs attacks in the following stages to simulate the real pattern of an attack: + +- Logon - *The bot gains credentials and accesses the node.* +- Port Scan - *The bot finds accessible database servers on the network.* +- Attacking - *The bot delivers the payload to the discovered database servers.* + +Each of these stages has a random, configurable probability of succeeding (by default 10%). The bot can also be configured to repeat the attack once complete. + Usage ----- - Create an instance and call ``configure`` to set: - - - Target database server IP - - Database password (if needed) - - SQL statement payload - + - Target database server IP + - Database password (if needed) + - SQL statement payload + - Probabilities for succeeding each of the above attack stages - Call ``run`` to connect and execute the statement. The bot handles connecting, executing the statement, and disconnecting. +In a simulation, the bot can be controlled by using ``DataManipulationAgent`` which calls ``run`` on the bot at configured timesteps. + Example ------- @@ -51,13 +60,81 @@ Example This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table. +Example with ``DataManipulationAgent`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If not using the data manipulation bot manually, it needs to be used with a data manipulation agent. Below is an example section of configuration file for setting up a simulation with data manipulation bot and agent. + +.. code-block:: yaml + + game_config: + # ... + agents: + - ref: data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + applications: + - application_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + # ... + + simulation: + network: + nodes: + - ref: client_1 + type: computer + # ... additional configuration here + applications: + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.1 + data_manipulation_p_of_success: 0.1 + payload: "DELETE" + server_ip: 192.168.1.14 + Implementation -------------- The bot extends ``DatabaseClient`` and leverages its connectivity. - Uses the Application base class for lifecycle management. -- Credentials and target IP set via ``configure``. +- Credentials, target IP and other options set via ``configure``. - ``run`` handles connecting, executing statement, and disconnecting. - SQL payload executed via ``query`` method. - Results in malicious SQL being executed on remote database server. diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f167dc2f..b68861e1 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -1,5 +1,5 @@ training_config: - rl_framework: RLLIB_single_agent + rl_framework: SB3 rl_algorithm: PPO seed: 333 n_learn_episodes: 1 @@ -36,31 +36,26 @@ agents: action_space: action_list: - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com - + - type: NODE_APPLICATION_EXECUTE options: nodes: - node_ref: client_2 + applications: + - application_ref: client_2_web_browser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + max_applications_per_node: 1 reward_function: reward_components: - type: DUMMY agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + start_settings: + start_step: 5 + frequency: 4 + variance: 3 - ref: client_1_data_manipulation_red_bot team: RED @@ -69,38 +64,20 @@ agents: observation_space: type: UC2RedObservation options: - nodes: - - node_ref: client_1 - observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nodes: {} action_space: action_list: - type: DONOTHING - # None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, service_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -98,7 +97,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "scan" + self.verb: str = "scan" class NodeServiceStopAction(NodeServiceAbstractAction): @@ -106,7 +105,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "stop" + self.verb: str = "stop" class NodeServiceStartAction(NodeServiceAbstractAction): @@ -114,7 +113,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "start" + self.verb: str = "start" class NodeServicePauseAction(NodeServiceAbstractAction): @@ -122,7 +121,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "pause" + self.verb: str = "pause" class NodeServiceResumeAction(NodeServiceAbstractAction): @@ -130,7 +129,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "resume" + self.verb: str = "resume" class NodeServiceRestartAction(NodeServiceAbstractAction): @@ -138,7 +137,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "restart" + self.verb: str = "restart" class NodeServiceDisableAction(NodeServiceAbstractAction): @@ -146,7 +145,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "disable" + self.verb: str = "disable" class NodeServiceEnableAction(NodeServiceAbstractAction): @@ -154,7 +153,38 @@ class NodeServiceEnableAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "enable" + self.verb: str = "enable" + + +class NodeApplicationAbstractAction(AbstractAction): + """ + Base class for application actions. + + Any action which applies to an application and uses node_id and application_id as its only two parameters can + inherit from this base class. + """ + + @abstractmethod + def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} + self.verb: str # define but don't initialise: defends against children classes not defining this + + def form_request(self, node_id: int, application_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id) + if node_uuid is None or application_uuid is None: + return ["do_nothing"] + return ["network", "node", node_uuid, "application", application_uuid, self.verb] + + +class NodeApplicationExecuteAction(NodeApplicationAbstractAction): + """Action which executes an application.""" + + def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) + self.verb: str = "execute" class NodeFolderAbstractAction(AbstractAction): @@ -169,7 +199,7 @@ class NodeFolderAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, folder_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -223,7 +253,7 @@ class NodeFileAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -240,7 +270,7 @@ class NodeFileScanAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "scan" + self.verb: str = "scan" class NodeFileCheckhashAction(NodeFileAbstractAction): @@ -248,7 +278,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "checkhash" + self.verb: str = "checkhash" class NodeFileDeleteAction(NodeFileAbstractAction): @@ -256,7 +286,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "delete" + self.verb: str = "delete" class NodeFileRepairAction(NodeFileAbstractAction): @@ -264,7 +294,7 @@ class NodeFileRepairAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "repair" + self.verb: str = "repair" class NodeFileRestoreAction(NodeFileAbstractAction): @@ -272,7 +302,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "restore" + self.verb: str = "restore" class NodeFileCorruptAction(NodeFileAbstractAction): @@ -280,7 +310,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "corrupt" + self.verb: str = "corrupt" class NodeAbstractAction(AbstractAction): @@ -294,7 +324,7 @@ class NodeAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -307,7 +337,7 @@ class NodeOSScanAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "scan" + self.verb: str = "scan" class NodeShutdownAction(NodeAbstractAction): @@ -315,7 +345,7 @@ class NodeShutdownAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "shutdown" + self.verb: str = "shutdown" class NodeStartupAction(NodeAbstractAction): @@ -323,7 +353,7 @@ class NodeStartupAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "startup" + self.verb: str = "startup" class NodeResetAction(NodeAbstractAction): @@ -331,7 +361,7 @@ class NodeResetAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "reset" + self.verb: str = "reset" class NetworkACLAddRuleAction(AbstractAction): @@ -489,7 +519,7 @@ class NetworkNICAbstractAction(AbstractAction): """ super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, nic_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -512,7 +542,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb = "enable" + self.verb: str = "enable" class NetworkNICDisableAction(NetworkNICAbstractAction): @@ -520,7 +550,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb = "disable" + self.verb: str = "disable" class ActionManager: @@ -536,6 +566,7 @@ class ActionManager: "NODE_SERVICE_RESTART": NodeServiceRestartAction, "NODE_SERVICE_DISABLE": NodeServiceDisableAction, "NODE_SERVICE_ENABLE": NodeServiceEnableAction, + "NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction, "NODE_FILE_SCAN": NodeFileScanAction, "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, "NODE_FILE_DELETE": NodeFileDeleteAction, @@ -562,9 +593,11 @@ class ActionManager: game: "PrimaiteGame", # reference to game for information lookup actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node + application_uuids: List[List[str]], # allows mapping index to application max_folders_per_node: int = 2, # allows calculating shape max_files_per_folder: int = 2, # allows calculating shape max_services_per_node: int = 2, # allows calculating shape + max_applications_per_node: int = 10, # allows calculating shape max_nics_per_node: int = 8, # allows calculating shape max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol @@ -600,8 +633,8 @@ class ActionManager: :type act_map: Optional[Dict[int, Dict]] """ self.game: "PrimaiteGame" = game - self.sim: Simulation = self.game.simulation self.node_uuids: List[str] = node_uuids + self.application_uuids: List[List[str]] = application_uuids self.protocols: List[str] = protocols self.ports: List[str] = ports @@ -611,7 +644,7 @@ class ActionManager: else: self.ip_address_list = [] for node_uuid in self.node_uuids: - node_obj = self.sim.network.nodes[node_uuid] + node_obj = self.game.simulation.network.nodes[node_uuid] nics = node_obj.nics for nic_uuid, nic_obj in nics.items(): self.ip_address_list.append(nic_obj.ip_address) @@ -622,6 +655,7 @@ class ActionManager: "num_folders": max_folders_per_node, "num_files": max_files_per_folder, "num_services": max_services_per_node, + "num_applications": max_applications_per_node, "num_nics": max_nics_per_node, "num_acl_rules": max_acl_rules, "num_protocols": len(self.protocols), @@ -734,7 +768,7 @@ class ActionManager: :rtype: Optional[str] """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] + node = self.game.simulation.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None @@ -752,7 +786,7 @@ class ActionManager: :rtype: Optional[str] """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] + node = self.game.simulation.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) if len(folder_uuids) <= folder_idx: return None @@ -771,10 +805,22 @@ class ActionManager: :rtype: Optional[str] """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] + node = self.game.simulation.network.nodes[node_uuid] service_uuids = list(node.services.keys()) return service_uuids[service_idx] if len(service_uuids) > service_idx else None + def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: + """Get the application UUID corresponding to the given node and service indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param application_idx: The index of the service on the node. + :type application_idx: int + :return: The UUID of the service. Or None if the node has fewer services than the given index. + :rtype: Optional[str] + """ + return self.application_uuids[node_idx][application_idx] + def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: """Get the internet protocol corresponding to the given index. @@ -819,7 +865,7 @@ class ActionManager: :rtype: str """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node_obj = self.sim.network.nodes[node_uuid] + node_obj = self.game.simulation.network.nodes[node_uuid] nics = list(node_obj.nics.keys()) if len(nics) <= nic_idx: return None diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py new file mode 100644 index 00000000..8237ce06 --- /dev/null +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -0,0 +1,48 @@ +import random +from typing import Dict, List, Tuple + +from gymnasium.core import ObsType + +from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot + + +class DataManipulationAgent(AbstractScriptedAgent): + """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + + data_manipulation_bots: List["DataManipulationBot"] = [] + next_execution_timestep: int = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) + + def _set_next_execution_timestep(self, timestep: int) -> None: + """Set the next execution timestep with a configured random variance. + + :param timestep: The timestep to add variance to. + """ + random_timestep_increment = random.randint( + -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + ) + self.next_execution_timestep = timestep + random_timestep_increment + + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + """Randomly sample an action from the action space. + + :param obs: _description_ + :type obs: ObsType + :param reward: _description_, defaults to None + :type reward: float, optional + :return: _description_ + :rtype: Tuple[str, Dict] + """ + current_timestep = self.action_manager.game.step_counter + + if current_timestep < self.next_execution_timestep: + return "DONOTHING", {"dummy": 0} + + self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) + + return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 75d209ce..fbbe5473 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,13 +1,64 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType +from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +if TYPE_CHECKING: + pass + + +class AgentStartSettings(BaseModel): + """Configuration values for when an agent starts performing actions.""" + + start_step: int = 5 + "The timestep at which an agent begins performing it's actions" + frequency: int = 5 + "The number of timesteps to wait between performing actions" + variance: int = 0 + "The amount the frequency can randomly change to" + + @model_validator(mode="after") + def check_variance_lt_frequency(self) -> "AgentStartSettings": + """ + Make sure variance is equal to or lower than frequency. + + This is because the calculation for the next execution time is now + (frequency +- variance). If variance were + greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again. + """ + if self.variance > self.frequency: + raise ValueError( + f"Agent start settings error: variance must be lower than frequency " + f"{self.variance=}, {self.frequency=}" + ) + return self + + +class AgentSettings(BaseModel): + """Settings for configuring the operation of an agent.""" + + start_settings: Optional[AgentStartSettings] = None + "Configuration for when an agent begins performing it's actions" + + @classmethod + def from_config(cls, config: Optional[Dict]) -> "AgentSettings": + """Construct agent settings from a config dictionary. + + :param config: A dict of options for the agent settings. + :type config: Dict + :return: The agent settings. + :rtype: AgentSettings + """ + if config is None: + return cls() + + return cls(**config) + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -18,6 +69,7 @@ class AbstractAgent(ABC): action_space: Optional[ActionManager], observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], + agent_settings: Optional[AgentSettings] = None, ) -> None: """ Initialize an agent. @@ -35,10 +87,7 @@ class AbstractAgent(ABC): self.action_manager: Optional[ActionManager] = action_space self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function - - # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info - # by for example specifying target ip addresses, or converting a node ID into a uuid - self.execution_definition = None + self.agent_settings = agent_settings or AgentSettings() def update_observation(self, state: Dict) -> ObsType: """ diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index e1ca3120..93fd81b8 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -162,7 +162,7 @@ class ServiceObservation(AbstractObservation): :return: Constructed service observation :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid]) + return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]]) class LinkObservation(AbstractObservation): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 8a1c2da4..3466114c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -213,7 +213,7 @@ class WebServer404Penalty(AbstractReward): _LOGGER.warn(msg) return DummyReward() # TODO: should we error out with incorrect inputs? Probably! node_uuid = game.ref_map_nodes[node_ref] - service_uuid = game.ref_map_services[service_ref].uuid + service_uuid = game.ref_map_services[service_ref] if not (node_uuid and service_uuid): msg = ( f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e96b9a42..38e9d5fc 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,5 +1,4 @@ """PrimAITE game - Encapsulates the simulation and agents.""" -from copy import deepcopy from ipaddress import IPv4Address from typing import Dict, List @@ -7,10 +6,11 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent +from primaite.game.agent.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.server import Server @@ -18,14 +18,14 @@ from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) @@ -57,9 +57,6 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - self._simulation_initial_state = deepcopy(self.simulation) - """The Simulation original state (deepcopy of the original Simulation).""" - self.agents: List[AbstractAgent] = [] """List of agents.""" @@ -75,16 +72,16 @@ class PrimaiteGame: self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" - self.ref_map_nodes: Dict[str, Node] = {} + self.ref_map_nodes: Dict[str, str] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" - self.ref_map_services: Dict[str, Service] = {} + self.ref_map_services: Dict[str, str] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" - self.ref_map_applications: Dict[str, Application] = {} + self.ref_map_applications: Dict[str, str] = {} """Mapping from human-readable application reference to application object. Used for parsing config files.""" - self.ref_map_links: Dict[str, Link] = {} + self.ref_map_links: Dict[str, str] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" def step(self): @@ -157,7 +154,7 @@ class PrimaiteGame: self.episode_counter += 1 self.step_counter = 0 _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation = deepcopy(self._simulation_initial_state) + self.simulation.reset_component_for_episode(episode=self.episode_counter) def close(self) -> None: """Close the game, this will close the simulation.""" @@ -187,10 +184,6 @@ class PrimaiteGame: sim = game.simulation net = sim.network - game.ref_map_nodes: Dict[str, Node] = {} - game.ref_map_services: Dict[str, Service] = {} - game.ref_map_links: Dict[str, Link] = {} - nodes_cfg = cfg["simulation"]["network"]["nodes"] links_cfg = cfg["simulation"]["network"]["links"] for node_cfg in nodes_cfg: @@ -203,6 +196,7 @@ class PrimaiteGame: subnet_mask=node_cfg["subnet_mask"], default_gateway=node_cfg["default_gateway"], dns_server=node_cfg["dns_server"], + operating_state=NodeOperatingState.ON, ) elif n_type == "server": new_node = Server( @@ -211,16 +205,26 @@ class PrimaiteGame: subnet_mask=node_cfg["subnet_mask"], default_gateway=node_cfg["default_gateway"], dns_server=node_cfg.get("dns_server"), + operating_state=NodeOperatingState.ON, ) elif n_type == "switch": - new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports")) + new_node = Switch( + hostname=node_cfg["hostname"], + num_ports=node_cfg.get("num_ports"), + operating_state=NodeOperatingState.ON, + ) elif n_type == "router": - new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports")) + new_node = Router( + hostname=node_cfg["hostname"], + num_ports=node_cfg.get("num_ports"), + operating_state=NodeOperatingState.ON, + ) if "ports" in node_cfg: for port_num, port_cfg in node_cfg["ports"].items(): new_node.configure_port( port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"] ) + # new_node.enable_port(port_num) if "acl" in node_cfg: for r_num, r_cfg in node_cfg["acl"].items(): # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating @@ -239,6 +243,7 @@ class PrimaiteGame: print("invalid node type") if "services" in node_cfg: for service_cfg in node_cfg["services"]: + new_service = None service_ref = service_cfg["ref"] service_type = service_cfg["type"] service_types_mapping = { @@ -247,13 +252,14 @@ class PrimaiteGame: "DatabaseClient": DatabaseClient, "DatabaseService": DatabaseService, "WebServer": WebServer, - "DataManipulationBot": DataManipulationBot, + "FTPClient": FTPClient, + "FTPServer": FTPServer, } if service_type in service_types_mapping: print(f"installing {service_type} on node {new_node.hostname}") new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] - game.ref_map_services[service_ref] = new_service + game.ref_map_services[service_ref] = new_service.uuid else: print(f"service type not found {service_type}") # service-dependent options @@ -268,30 +274,49 @@ class PrimaiteGame: if "domain_mapping" in opt: for domain, ip in opt["domain_mapping"].items(): new_service.dns_register(domain, ip) + if service_type == "DatabaseService": + if "options" in service_cfg: + opt = service_cfg["options"] + if "backup_server_ip" in opt: + new_service.configure_backup(backup_server=IPv4Address(opt["backup_server_ip"])) + new_service.start() + if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: + new_application = None application_ref = application_cfg["ref"] application_type = application_cfg["type"] application_types_mapping = { "WebBrowser": WebBrowser, + "DataManipulationBot": DataManipulationBot, } if application_type in application_types_mapping: new_node.software_manager.install(application_types_mapping[application_type]) new_application = new_node.software_manager.software[application_type] - game.ref_map_applications[application_ref] = new_application + game.ref_map_applications[application_ref] = new_application.uuid else: print(f"application type not found {application_type}") + + if application_type == "DataManipulationBot": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.configure( + server_ip_address=IPv4Address(opt.get("server_ip")), + payload=opt.get("payload"), + port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), + data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), + ) + elif application_type == "WebBrowser": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.target_url = opt.get("target_url") if "nics" in node_cfg: for nic_num, nic_cfg in node_cfg["nics"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) net.add_node(new_node) new_node.power_on() - game.ref_map_nodes[ - node_ref - ] = ( - new_node.uuid - ) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object + game.ref_map_nodes[node_ref] = new_node.uuid # 2. create links between nodes for link_cfg in links_cfg: @@ -323,11 +348,25 @@ class PrimaiteGame: # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] + action_space_cfg["options"]["application_uuids"] = [] + # if a list of nodes is defined, convert them from node references to node UUIDs for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}): if "node_ref" in action_node_option: node_uuid = game.ref_map_nodes[action_node_option["node_ref"]] action_space_cfg["options"]["node_uuids"].append(node_uuid) + + if "applications" in action_node_option: + node_application_uuids = [] + for application_option in action_node_option["applications"]: + # TODO: fix inconsistency with node uuids and application uuids. The node object get added to + # node_uuid, whereas here the application gets added by uuid. + application_uuid = game.ref_map_applications[application_option["application_ref"]] + node_application_uuids.append(application_uuid) + + action_space_cfg["options"]["application_uuids"].append(node_application_uuids) + else: + action_space_cfg["options"]["application_uuids"].append([]) # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. # However, it's not possible to specify the node uuids directly in the config, as they are generated @@ -345,6 +384,8 @@ class PrimaiteGame: # CREATE REWARD FUNCTION rew_function = RewardFunction.from_config(reward_function_cfg, game=game) + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) + # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": # TODO: implement non-random agents and fix this parsing @@ -353,6 +394,7 @@ class PrimaiteGame: action_space=action_space, observation_space=obs_space, reward_function=rew_function, + agent_settings=agent_settings, ) game.agents.append(new_agent) elif agent_type == "ProxyAgent": @@ -365,16 +407,17 @@ class PrimaiteGame: game.agents.append(new_agent) game.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": - new_agent = RandomAgent( + new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=rew_function, + agent_settings=agent_settings, ) game.agents.append(new_agent) else: print("agent type not found") - game._simulation_initial_state = deepcopy(game.simulation) # noqa + game.simulation.set_original_state() return game diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py new file mode 100644 index 00000000..19a86237 --- /dev/null +++ b/src/primaite/game/science.py @@ -0,0 +1,16 @@ +from random import random + + +def simulate_trial(p_of_success: float) -> bool: + """ + Simulates the outcome of a single trial in a Bernoulli process. + + This function returns True with a probability 'p_of_success', simulating a success outcome in a single + trial of a Bernoulli process. When this function is executed multiple times, the set of outcomes follows + a binomial distribution. This is useful in scenarios where one needs to model or simulate events that + have two possible outcomes (success or failure) with a fixed probability of success. + + :param p_of_success: The probability of success in a single trial, ranging from 0 to 1. + :returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False. + """ + return random() < p_of_success diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb new file mode 100644 index 00000000..3950ef10 --- /dev/null +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" + ] + } + ], + "source": [ + "from primaite.session.session import PrimaiteSession\n", + "from primaite.game.game import PrimaiteGame\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.simulator.system.services.database.database_service import DatabaseService\n", + "\n", + "import yaml" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "installing DNSServer on node domain_controller\n", + "installing DatabaseClient on node web_server\n", + "installing WebServer on node web_server\n", + "installing DatabaseService on node database_server\n", + "installing FTPClient on node database_server\n", + "installing FTPServer on node backup_server\n", + "installing DNSClient on node client_1\n", + "installing DNSClient on node client_2\n" + ] + } + ], + "source": [ + "\n", + "with open(example_config_path(),'r') as cfgfile:\n", + " cfg = yaml.safe_load(cfgfile)\n", + "game = PrimaiteGame.from_config(cfg)\n", + "net = game.simulation.network\n", + "database_server = net.get_node_by_hostname('database_server')\n", + "web_server = net.get_node_by_hostname('web_server')\n", + "client_1 = net.get_node_by_hostname('client_1')\n", + "\n", + "db_service = database_server.software_manager.software[\"DatabaseService\"]\n", + "db_client = web_server.software_manager.software[\"DatabaseClient\"]\n", + "# db_client.run()\n", + "db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n", + "db_manipulation_bot.port_scan_p_of_success=1.0\n", + "db_manipulation_bot.data_manipulation_p_of_success=1.0\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "db_client.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_service.backup_database()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_client.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "db_manipulation_bot.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_client.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_service.restore_backup()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_client.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "db_manipulation_bot.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client_1.ping(database_server.ethernet_port[1].ip_address)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import validate_call, BaseModel" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class A(BaseModel):\n", + " x:int\n", + "\n", + " @validate_call\n", + " def increase_x(self, by:int) -> None:\n", + " self.x += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "my_a = A(x=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "ename": "ValidationError", + "evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n", + "File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n", + "\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float" + ] + } + ], + "source": [ + "my_a.increase_x(3.2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index db24db60..a5fdade9 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -37,7 +37,7 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = {} - + print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}") return next_obs, reward, terminated, truncated, info def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 80b63ba7..3919902a 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -88,16 +88,16 @@ class PrimaiteSession: @classmethod def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + io_manager = SessionIO(SessionIOSettings(**io_settings)) + game = PrimaiteGame.from_config(cfg) sess = cls(game=game) - + sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) - # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": sess.env = PrimaiteRayEnv(env_config={"game": game}) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 9ead877e..18a470cd 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,6 +153,8 @@ class SimComponent(BaseModel): uuid: str """The component UUID.""" + _original_state: Dict = {} + def __init__(self, **kwargs): if not kwargs.get("uuid"): kwargs["uuid"] = str(uuid4()) @@ -160,6 +162,16 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None + # @abstractmethod + def set_original_state(self): + """Sets the original state.""" + pass + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + for key, value in self._original_state.items(): + self.__setattr__(key, value) + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager for this component. @@ -227,14 +239,6 @@ class SimComponent(BaseModel): """ pass - def reset_component_for_episode(self, episode: int): - """ - Reset this component to its original state for a new episode. - - Override this method with anything that needs to happen within the component for it to be reset. - """ - pass - @property def parent(self) -> "SimComponent": """Reference to the parent object which manages this object. diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d235c00e..1402a474 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,6 +42,19 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True + def set_original_state(self): + """Sets the original state.""" + vals_to_include = { + "num_logons", + "num_logoffs", + "num_group_changes", + "username", + "password", + "account_type", + "enabled", + } + self._original_state = self.model_dump(include=vals_to_include) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index d9b02e8e..8f0abb3c 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,6 +73,18 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().reset_component_for_episode(episode) + @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 41f02270..dc6f01a3 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -35,6 +35,36 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") + def set_original_state(self): + """Sets the original state.""" + for folder in self.folders.values(): + folder.set_original_state() + super().set_original_state() + # Capture a list of all 'original' file uuids + self._original_state["original_folder_uuids"] = list(self.folders.keys()) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Move any 'original' folder that have been deleted back to folders + original_folder_uuids = self._original_state.pop("original_folder_uuids") + for uuid in original_folder_uuids: + if uuid in self.deleted_folders: + self.folders[uuid] = self.deleted_folders.pop(uuid) + + # Clear any other deleted folders that aren't original (have been created by agent) + self.deleted_folders.clear() + + # Now clear all non-original folders created by agent + current_folder_uuids = list(self.folders.keys()) + for uuid in current_folder_uuids: + if uuid not in original_folder_uuids: + self.folders.pop(uuid) + + # Now reset all remaining folders + for folder in self.folders.values(): + folder.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index fbe5f4b3..86cd1ee7 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." + def set_original_state(self): + """Sets the original state.""" + vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"} + self._original_state = self.model_dump(include=vals_to_keep) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index f0d55ef8..8e577097 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -51,6 +51,44 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") + def set_original_state(self): + """Sets the original state.""" + for file in self.files.values(): + file.set_original_state() + super().set_original_state() + vals_to_include = { + "scan_duration", + "scan_countdown", + "red_scan_duration", + "red_scan_countdown", + "restore_duration", + "restore_countdown", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + self._original_state["original_file_uuids"] = list(self.files.keys()) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Move any 'original' file that have been deleted back to files + original_file_uuids = self._original_state.pop("original_file_uuids") + for uuid in original_file_uuids: + if uuid in self.deleted_files: + self.files[uuid] = self.deleted_files.pop(uuid) + + # Clear any other deleted files that aren't original (have been created by agent) + self.deleted_files.clear() + + # Now clear all non-original files created by agent + current_file_uuids = list(self.files.keys()) + for uuid in current_file_uuids: + if uuid not in original_file_uuids: + self.files.pop(uuid) + + # Now reset all remaining files + for file in self.files.values(): + file.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index a356549a..7ef55c3c 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -43,6 +43,20 @@ class Network(SimComponent): self._nx_graph = MultiGraph() + def set_original_state(self): + """Sets the original state.""" + for node in self.nodes.values(): + node.set_original_state() + for link in self.links.values(): + link.set_original_state() + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + for node in self.nodes.values(): + node.reset_component_for_episode(episode) + for link in self.links.values(): + link.reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() self._node_request_manager = RequestManager() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 81272547..c6ee373e 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -121,6 +121,21 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().reset_component_for_episode(episode) + if episode and self.pcap: + self.pcap.current_episode = episode + self.pcap.setup_logger() + self.enable() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -308,6 +323,14 @@ class SwitchPort(SimComponent): kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} + self._original_state = self.model_dump(include=vals_to_include) + super().set_original_state() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -454,6 +477,14 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"bandwidth", "current_load"} + self._original_state = self.model_dump(include=vals_to_include) + super().set_original_state() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -536,15 +567,6 @@ class Link(SimComponent): return True return False - def reset_component_for_episode(self, episode: int): - """ - Link reset function. - - Reset: - - returns the link current_load to 0. - """ - self.current_load = 0 - def __str__(self) -> str: return f"{self.endpoint_a}<-->{self.endpoint_b}" @@ -584,6 +606,10 @@ class ARPCache: ) print(table) + def clear(self): + """Clears the arp cache.""" + self.arp.clear() + def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False): """ Add an ARP entry to the cache. @@ -756,6 +782,10 @@ class ICMP: self.arp: ARPCache = arp_cache self.request_replies = {} + def clear(self): + """Clears the ICMP request replies tracker.""" + self.request_replies.clear() + def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False): """ Process an ICMP packet, including handling echo requests and replies. @@ -959,6 +989,62 @@ class Node(SimComponent): self.arp.nics = self.nics self.session_manager.software_manager = self.software_manager self._install_system_software() + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + for software in self.software_manager.software.values(): + software.set_original_state() + + self.file_system.set_original_state() + + for nic in self.nics.values(): + nic.set_original_state() + + vals_to_include = { + "hostname", + "default_gateway", + "operating_state", + "revealed_to_red", + "start_up_duration", + "start_up_countdown", + "shut_down_duration", + "shut_down_countdown", + "is_resetting", + "node_scan_duration", + "node_scan_countdown", + "red_scan_countdown", + } + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Reset ARP Cache + self.arp.clear() + + # Reset ICMP + self.icmp.clear() + + # Reset Session Manager + self.session_manager.clear() + + # Reset software + for software in self.software_manager.software.values(): + software.reset_component_for_episode(episode) + + # Reset File System + self.file_system.reset_component_for_episode(episode) + + # Reset all Nics + for nic in self.nics.values(): + nic.reset_component_for_episode(episode) + + # + if episode and self.sys_log: + self.sys_log.current_episode = episode + self.sys_log.setup_logger() + + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will @@ -1442,99 +1528,3 @@ class Node(SimComponent): if isinstance(item, Service): return item.uuid in self.services return None - - -class Switch(Node): - """A class representing a Layer 2 network switch.""" - - num_ports: int = 24 - "The number of ports on the switch." - switch_ports: Dict[int, SwitchPort] = {} - "The SwitchPorts on the switch." - mac_address_table: Dict[str, SwitchPort] = {} - "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - - def __init__(self, **kwargs): - super().__init__(**kwargs) - if not self.switch_ports: - self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)} - for port_num, port in self.switch_ports.items(): - port._connected_node = self - port.parent = self - port.port_num = port_num - - def show(self): - """Prints a table of the SwitchPorts on the Switch.""" - table = PrettyTable(["Port", "MAC Address", "Speed", "Status"]) - - for port_num, port in self.switch_ports.items(): - table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) - print(table) - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - return { - "uuid": self.uuid, - "num_ports": self.num_ports, # redundant? - "ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}, - "mac_address_table": {mac: port for mac, port in self.mac_address_table.items()}, - } - - def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): - mac_table_port = self.mac_address_table.get(mac_address) - if not mac_table_port: - self.mac_address_table[mac_address] = switch_port - self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}") - else: - if mac_table_port != switch_port: - self.mac_address_table.pop(mac_address) - self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}") - self._add_mac_table_entry(mac_address, switch_port) - - def forward_frame(self, frame: Frame, incoming_port: SwitchPort): - """ - Forward a frame to the appropriate port based on the destination MAC address. - - :param frame: The Frame to be forwarded. - :param incoming_port: The port number from which the frame was received. - """ - src_mac = frame.ethernet.src_mac_addr - dst_mac = frame.ethernet.dst_mac_addr - self._add_mac_table_entry(src_mac, incoming_port) - - outgoing_port = self.mac_address_table.get(dst_mac) - if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff": - outgoing_port.send_frame(frame) - else: - # If the destination MAC is not in the table, flood to all ports except incoming - for port in self.switch_ports.values(): - if port != incoming_port: - port.send_frame(frame) - - def disconnect_link_from_port(self, link: Link, port_number: int): - """ - Disconnect a given link from the specified port number on the switch. - - :param link: The Link object to be disconnected. - :param port_number: The port number on the switch from where the link should be disconnected. - :raise NetworkError: When an invalid port number is provided or the link does not match the connection. - """ - port = self.switch_ports.get(port_number) - if port is None: - msg = f"Invalid port number {port_number} on the switch" - _LOGGER.error(msg) - raise NetworkError(msg) - - if port._connected_link != link: - msg = f"The link does not match the connection at port number {port_number}" - _LOGGER.error(msg) - raise NetworkError(msg) - - port.disconnect_link() diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index c2a38aba..34b92a07 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -52,6 +52,11 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) + def set_original_state(self): + """Sets the original state.""" + vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} + self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -93,6 +98,18 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + self.implicit_rule.set_original_state() + vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} + self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.implicit_rule.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -337,6 +354,11 @@ class RouteEntry(SimComponent): kwargs[key] = IPv4Address(kwargs[key]) super().__init__(**kwargs) + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} + self._original_values = self.model_dump(include=vals_to_include) + def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -368,6 +390,18 @@ class RouteTable(SimComponent): routes: List[RouteEntry] = [] sys_log: SysLog + def set_original_state(self): + """Sets the original state.""" + """Sets the original state.""" + super().set_original_state() + self._original_state["routes_orig"] = self.routes + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.routes.clear() + self.routes = self._original_state["routes_orig"] + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -638,6 +672,26 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + self.acl.set_original_state() + self.route_table.set_original_state() + vals_to_include = {"num_ports"} + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.arp.clear() + self.acl.reset_component_for_episode(episode) + self.route_table.reset_component_for_episode(episode) + for i, nic in self.ethernet_ports.items(): + nic.reset_component_for_episode(episode) + self.enable_port(i) + + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("acl", RequestType(func=self.acl._request_manager)) @@ -730,6 +784,7 @@ class Router(Node): dst_ip_address=dst_ip_address, dst_port=dst_port, ) + if not permitted: at_port = self._get_port_of_nic(from_nic) self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}") @@ -763,6 +818,7 @@ class Router(Node): nic.ip_address = ip_address nic.subnet_mask = subnet_mask self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}") + self.set_original_state() def enable_port(self, port: int): """ diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index c0f9a07e..b7bd2e95 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -140,7 +140,12 @@ def arcd_uc2_network() -> Network: network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] - db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") + db_manipulation_bot.configure( + server_ip_address=IPv4Address("192.168.1.14"), + payload="DELETE", + port_scan_p_of_success=1.0, + data_manipulation_p_of_success=1.0, + ) # Client 2 client_2 = Computer( @@ -152,6 +157,8 @@ def arcd_uc2_network() -> Network: operating_state=NodeOperatingState.ON, ) client_2.power_on() + web_browser = client_2.software_manager.software["WebBrowser"] + web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) # Domain Controller diff --git a/src/primaite/simulator/network/protocols/http.py b/src/primaite/simulator/network/protocols/http.py index 2dba2614..b88916a9 100644 --- a/src/primaite/simulator/network/protocols/http.py +++ b/src/primaite/simulator/network/protocols/http.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, IntEnum from primaite.simulator.network.protocols.packet import DataPacket @@ -25,7 +25,7 @@ class HttpRequestMethod(Enum): """Apply partial modifications to a resource.""" -class HttpStatusCode(Enum): +class HttpStatusCode(IntEnum): """List of available HTTP Statuses.""" OK = 200 diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 8e820ec8..c529ed04 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -9,7 +9,7 @@ class Simulation(SimComponent): """Top-level simulation object which holds a reference to all other parts of the simulation.""" network: Network - domain: DomainController + # domain: DomainController def __init__(self, **kwargs): """Initialise the Simulation.""" @@ -21,6 +21,14 @@ class Simulation(SimComponent): super().__init__(**kwargs) + def set_original_state(self): + """Sets the original state.""" + self.network.set_original_state() + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.network.reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # pass through network requests to the network objects diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index d2f9772d..898e5917 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -41,6 +41,12 @@ class Application(IOSoftware): self.health_state_visible = SoftwareHealthState.UNUSED self.health_state_actual = SoftwareHealthState.UNUSED + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ @@ -90,6 +96,10 @@ class Application(IOSoftware): self.sys_log.info(f"Running Application {self.name}") self.operating_state = ApplicationOperatingState.RUNNING + def _application_loop(self): + """The main application loop.""" + pass + def close(self) -> None: """Close the Application.""" if self.operating_state == ApplicationOperatingState.RUNNING: @@ -98,23 +108,11 @@ class Application(IOSoftware): def install(self) -> None: """Install Application.""" - if self._can_perform_action(): - return - super().install() if self.operating_state == ApplicationOperatingState.CLOSED: self.sys_log.info(f"Installing Application {self.name}") self.operating_state = ApplicationOperatingState.INSTALLING - def reset_component_for_episode(self, episode: int): - """ - Resets the Application component for a new episode. - - This method ensures the Application is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 9cb87bf6..37f85b28 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -31,6 +31,13 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"server_ip_address", "server_password", "connected"} + self._original_state.update(self.model_dump(include=vals_to_include)) def describe_state(self) -> Dict: """ @@ -78,11 +85,11 @@ class DatabaseClient(Application): """ if is_reattempt: if self.connected: - self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} authorised") + self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised") self.server_ip_address = server_ip_address return self.connected else: - self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} declined") + self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined") return False payload = {"type": "connect_request", "password": password} software_manager: SoftwareManager = self.software_manager @@ -135,8 +142,8 @@ class DatabaseClient(Application): def run(self) -> None: """Run the DatabaseClient.""" super().run() - self.operating_state = ApplicationOperatingState.RUNNING - self.connect() + if self.operating_state == ApplicationOperatingState.RUNNING: + self.connect() def query(self, sql: str, is_reattempt: bool = False) -> bool: """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 71e30c7f..bf304d7b 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -2,6 +2,7 @@ from ipaddress import IPv4Address from typing import Dict, Optional from urllib.parse import urlparse +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.http import ( HttpRequestMethod, HttpRequestPacket, @@ -21,6 +22,8 @@ class WebBrowser(Application): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ + target_url: Optional[str] = None + domain_name_ip_address: Optional[IPv4Address] = None "The IP address of the domain name for the webpage." @@ -35,8 +38,23 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) + self.set_original_state() self.run() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + rm.add_request( + name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa + ) + + return rm + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of the WebBrowser. @@ -47,16 +65,9 @@ class WebBrowser(Application): state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None def reset_component_for_episode(self, episode: int): - """ - Resets the Application component for a new episode. + """Reset the original state of the SimComponent.""" - This method ensures the Application is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - self.domain_name_ip_address = None - self.latest_response = None - - def get_webpage(self, url: str) -> bool: + def get_webpage(self) -> bool: """ Retrieve the webpage. @@ -65,6 +76,7 @@ class WebBrowser(Application): :param: url: The address of the web page the browser requests :type: url: str """ + url = self.target_url if not self._can_perform_action(): return False @@ -79,7 +91,6 @@ class WebBrowser(Application): # get the IP address of the domain name via DNS dns_client: DNSClient = self.software_manager.software["DNSClient"] - domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) # if domain does not exist, the request fails diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index c2faeb10..1539e024 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -34,9 +34,12 @@ class PacketCapture: "The IP address associated with the PCAP logs." self.switch_port_number = switch_port_number "The SwitchPort number." - self._setup_logger() - def _setup_logger(self): + self.current_episode: int = 1 + + self.setup_logger() + + def setup_logger(self): """Set up the logger configuration.""" log_path = self._get_log_path() @@ -75,7 +78,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT.path / self.hostname + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 360b5e73..8658f155 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -93,6 +93,11 @@ class SessionManager: """ pass + def clear(self): + """Clears the sessions.""" + self.sessions_by_key.clear() + self.sessions_by_uuid.clear() + @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 7ac6df85..41ce8fee 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -31,9 +31,10 @@ class SysLog: :param hostname: The hostname associated with the system logs being recorded. """ self.hostname = hostname - self._setup_logger() + self.current_episode: int = 1 + self.setup_logger() - def _setup_logger(self): + def setup_logger(self): """ Configures the logger for this SysLog instance. @@ -80,7 +81,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT.path / self.hostname + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index c4e94845..ad9af335 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,6 +24,12 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 740ed4fd..45e469fb 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -38,6 +38,23 @@ class DatabaseService(Service): self._db_file: File self._create_db_file() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = { + "password", + "connections", + "backup_server", + "latest_backup_directory", + "latest_backup_file_name", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.connections.clear() + super().reset_component_for_episode(episode) + def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 47196d15..3d425bfa 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,6 +29,17 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"dns_server"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.dns_cache.clear() + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index b6d4961f..30278ab1 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,6 +28,22 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"dns_table"} + self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + print("dns reset") + print("DNSServer original state", self._original_state) + self.dns_table.clear() + for key, value in self._original_state["dns_table_orig"].items(): + self.dns_table[key] = value + super().reset_component_for_episode(episode) + self.show() + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -68,15 +84,6 @@ class DNSServer(Service): self.dns_table[domain_name] = domain_ip_address - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def receive( self, payload: Any, diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 8dc2eeab..b0b34396 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -1,27 +1,67 @@ +from enum import IntEnum from ipaddress import IPv4Address from typing import Optional +from primaite.game.science import simulate_trial +from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient -class DataManipulationBot(DatabaseClient): +class DataManipulationAttackStage(IntEnum): """ - Red Agent Data Integration Service. + Enumeration representing different stages of a data manipulation attack. - The Service represents a bot that causes files/folders in the File System to - become corrupted. + This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the + simulation. Each stage represents a specific phase in the attack process. """ + NOT_STARTED = 0 + "Indicates that the attack has not started yet." + LOGON = 1 + "The stage where logon procedures are simulated." + PORT_SCAN = 2 + "Represents the stage of performing a horizontal port scan on the target." + ATTACKING = 3 + "Stage of actively attacking the target." + SUCCEEDED = 4 + "Indicates the attack has been successfully completed." + FAILED = 5 + "Signifies that the attack has failed." + + +class DataManipulationBot(DatabaseClient): + """A bot that simulates a script which performs a SQL injection attack.""" + server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None server_password: Optional[str] = None + port_scan_p_of_success: float = 0.1 + data_manipulation_p_of_success: float = 0.1 + + attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED + repeat: bool = False + "Whether to repeat attacking once finished." def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DataManipulationBot" + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + + rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) + + return rm + def configure( - self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None + self, + server_ip_address: IPv4Address, + server_password: Optional[str] = None, + payload: Optional[str] = None, + port_scan_p_of_success: float = 0.1, + data_manipulation_p_of_success: float = 0.1, + repeat: bool = False, ): """ Configure the DataManipulatorBot to communicate with a DatabaseService. @@ -29,26 +69,111 @@ class DataManipulationBot(DatabaseClient): :param server_ip_address: The IP address of the Node the DatabaseService is on. :param server_password: The password on the DatabaseService. :param payload: The data manipulation query payload. + :param port_scan_p_of_success: The probability of success for the port scan stage. + :param data_manipulation_p_of_success: The probability of success for the data manipulation stage. + :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address self.payload = payload self.server_password = server_password + self.port_scan_p_of_success = port_scan_p_of_success + self.data_manipulation_p_of_success = data_manipulation_p_of_success + self.repeat = repeat self.sys_log.info( - f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}." + f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, " + f"{repeat=}." ) - def run(self): - """Run the DataManipulationBot.""" - if self.server_ip_address and self.payload: - self.sys_log.info(f"{self.name}: Attempting to start the {self.name}") - super().run() - else: - self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_ip_address and payload.") + def _logon(self): + """ + Simulate the logon process as the initial stage of the attack. - def attack(self): - """Run the data manipulation attack.""" - if not self.connected: - self.connect() - if self.connected: - self.query(self.payload) - self.sys_log.info(f"{self.name} payload delivered: {self.payload}") + Advances the attack stage to `LOGON` if successful. + """ + if self.attack_stage == DataManipulationAttackStage.NOT_STARTED: + # Bypass this stage as we're not dealing with logon for now + self.sys_log.info(f"{self.name}: ") + self.attack_stage = DataManipulationAttackStage.LOGON + + def _perform_port_scan(self, p_of_success: Optional[float] = 0.1): + """ + Perform a simulated port scan to check for open SQL ports. + + Advances the attack stage to `PORT_SCAN` if successful. + + :param p_of_success: Probability of successful port scan, by default 0.1. + """ + if self.attack_stage == DataManipulationAttackStage.LOGON: + # perform a port scan to identify that the SQL port is open on the server + if simulate_trial(p_of_success): + self.sys_log.info(f"{self.name}: Performing port scan") + # perform the port scan + port_is_open = True # Temporary; later we can implement NMAP port scan. + if port_is_open: + self.sys_log.info(f"{self.name}: ") + self.attack_stage = DataManipulationAttackStage.PORT_SCAN + + def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1): + """ + Execute the data manipulation attack on the target. + + Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful. + + :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. + """ + if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: + # perform the actual data manipulation attack + if simulate_trial(p_of_success): + self.sys_log.info(f"{self.name}: Performing data manipulation") + # perform the attack + if not self.connected: + self.connect() + if self.connected: + self.query(self.payload) + self.sys_log.info(f"{self.name} payload delivered: {self.payload}") + attack_successful = True + if attack_successful: + self.sys_log.info(f"{self.name}: Data manipulation successful") + self.attack_stage = DataManipulationAttackStage.SUCCEEDED + else: + self.sys_log.info(f"{self.name}: Data manipulation failed") + self.attack_stage = DataManipulationAttackStage.FAILED + + def run(self): + """ + Run the Data Manipulation Bot. + + Calls the parent classes execute method before starting the application loop. + """ + super().run() + self._application_loop() + + def _application_loop(self): + """ + The main application loop of the bot, handling the attack process. + + This is the core loop where the bot sequentially goes through the stages of the attack. + """ + if self.operating_state != ApplicationOperatingState.RUNNING: + return + if self.server_ip_address and self.payload and self.operating_state: + self.sys_log.info(f"{self.name}: Running") + self._logon() + self._perform_port_scan(p_of_success=self.port_scan_p_of_success) + self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success) + + if self.repeat and self.attack_stage in ( + DataManipulationAttackStage.SUCCEEDED, + DataManipulationAttackStage.FAILED, + ): + self.attack_stage = DataManipulationAttackStage.NOT_STARTED + else: + self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.") + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the bot, triggering the application loop. + + :param timestep: The timestep value to update the bot's state. + """ + self._application_loop() diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 04a4603a..6d6cda86 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -80,6 +80,12 @@ class Service(IOSoftware): self.health_state_visible = SoftwareHealthState.UNUSED self.health_state_actual = SoftwareHealthState.UNUSED + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} + self._original_state.update(self.model_dump(include=vals_to_include)) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) @@ -107,15 +113,6 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 63df2f7d..becbf9f9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -17,7 +17,21 @@ from primaite.simulator.system.services.service import Service class WebServer(Service): """Class used to represent a Web Server Service in simulation.""" - last_response_status_code: Optional[HttpStatusCode] = None + _last_response_status_code: Optional[HttpStatusCode] = None + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self._last_response_status_code = None + super().reset_component_for_episode(episode) + + @property + def last_response_status_code(self) -> HttpStatusCode: + """The latest http response code.""" + return self._last_response_status_code + + @last_response_status_code.setter + def last_response_status_code(self, val: Any): + self._last_response_status_code = val def describe_state(self) -> Dict: """ @@ -30,8 +44,9 @@ class WebServer(Service): """ state = super().describe_state() state["last_response_status_code"] = ( - self.last_response_status_code.value if self.last_response_status_code else None + self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) + print(state) return state def __init__(self, **kwargs): diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 5564bd48..87802a7b 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -90,6 +90,19 @@ class Software(SimComponent): folder: Optional[Folder] = None "The folder on the file system the Software uses." + def set_original_state(self): + """Sets the original state.""" + vals_to_include = { + "name", + "health_state_actual", + "health_state_visible", + "criticality", + "patching_count", + "scanning_count", + "revealed_to_red", + } + self._original_state = self.model_dump(include=vals_to_include) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -132,16 +145,6 @@ class Software(SimComponent): ) return state - def reset_component_for_episode(self, episode: int): - """ - Resets the software component for a new episode. - - This method should ensure the software is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a - "reset" should be implemented in subclasses. - """ - pass - def set_health_state(self, health_state: SoftwareHealthState) -> None: """ Assign a new health state to this software. @@ -204,6 +207,12 @@ class IOSoftware(Software): port: Port "The port to which the software is connected." + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index b5e43ab3..9070f246 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -27,14 +27,6 @@ agents: action_space: action_list: - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com - options: nodes: - node_ref: client_2 @@ -48,10 +40,11 @@ agents: reward_components: - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - ref: client_1_data_manipulation_red_bot team: RED @@ -60,38 +53,20 @@ agents: observation_space: type: UC2RedObservation options: - nodes: - - node_ref: client_1 - observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nodes: {} action_space: action_list: - type: DONOTHING - # Node: network = arcd_uc2_network() + return network.get_node_by_hostname("client_1") - client_1: Node = network.get_node_by_hostname("client_1") - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] +@pytest.fixture +def dm_bot(dm_client) -> DataManipulationBot: + return dm_client.software_manager.software["DataManipulationBot"] + + +def test_create_dm_bot(dm_client): + data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"] assert data_manipulation_bot.name == "DataManipulationBot" assert data_manipulation_bot.port == Port.POSTGRES_SERVER assert data_manipulation_bot.protocol == IPProtocol.TCP assert data_manipulation_bot.payload == "DELETE" + + +def test_dm_bot_logon(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.NOT_STARTED + + dm_bot._logon() + + assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON + + +def test_dm_bot_perform_port_scan_no_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.LOGON + + dm_bot._perform_port_scan(p_of_success=0.0) + + assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON + + +def test_dm_bot_perform_port_scan_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.LOGON + + dm_bot._perform_port_scan(p_of_success=1.0) + + assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN + + +def test_dm_bot_perform_data_manipulation_no_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + + dm_bot._perform_data_manipulation(p_of_success=0.0) + + assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN + + +def test_dm_bot_perform_data_manipulation_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + dm_bot.operating_state = ApplicationOperatingState.RUNNING + + dm_bot._perform_data_manipulation(p_of_success=1.0) + + assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) + assert dm_bot.connected