From 23fd9c3839288a9839d0dc3327aac769a4c201f1 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 13 Nov 2023 15:55:14 +0000 Subject: [PATCH 01/35] #1859 - Started giving the red agent some 'intelligence' and a sense of a state. Changed Application.run to .execute. --- src/primaite/game/agent/GATE_agents.py | 8 +- src/primaite/game/agent/interface.py | 2 + src/primaite/game/science.py | 16 +++ src/primaite/game/session.py | 4 +- .../system/applications/application.py | 4 + .../system/applications/database_client.py | 8 +- .../system/applications/web_browser.py | 2 +- .../red_services/data_manipulation_bot.py | 134 +++++++++++++++--- tests/conftest.py | 1 - .../system/test_web_client_server.py | 6 +- 10 files changed, 151 insertions(+), 34 deletions(-) create mode 100644 src/primaite/game/science.py diff --git a/src/primaite/game/agent/GATE_agents.py b/src/primaite/game/agent/GATE_agents.py index e50d7831..e4ee16ca 100644 --- a/src/primaite/game/agent/GATE_agents.py +++ b/src/primaite/game/agent/GATE_agents.py @@ -19,10 +19,10 @@ class GATERLAgent(AbstractGATEAgent): def __init__( self, - agent_name: str | None, - action_space: ActionManager | None, - observation_space: ObservationSpace | None, - reward_function: RewardFunction | None, + agent_name: Optional[str], + action_space: Optional[ActionManager], + observation_space: Optional[ObservationSpace], + reward_function: Optional[RewardFunction], ) -> None: super().__init__(agent_name, action_space, observation_space, reward_function) self.most_recent_action: ActType diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 89f27f3f..78d18a68 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -109,6 +109,8 @@ class RandomAgent(AbstractScriptedAgent): """ return self.action_space.get_action(self.action_space.space.sample()) +class DataManipulationAgent(AbstractScriptedAgent): + pass class AbstractGATEAgent(AbstractAgent): """Base class for actors controlled via external messages, such as RL policies.""" diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py new file mode 100644 index 00000000..f6215127 --- /dev/null +++ b/src/primaite/game/science.py @@ -0,0 +1,16 @@ +from random import random + + +def simulate_trial(p_of_success: float): + """ + Simulates the outcome of a single trial in a Bernoulli process. + + This function returns True with a probability 'p_of_success', simulating a success outcome in a single + trial of a Bernoulli process. When this function is executed multiple times, the set of outcomes follows + a binomial distribution. This is useful in scenarios where one needs to model or simulate events that + have two possible outcomes (success or failure) with a fixed probability of success. + + :param p_of_success: The probability of success in a single trial, ranging from 0 to 1. + :returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False. + """ + return random() < p_of_success diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index d40d0754..9c2bb6b7 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -60,7 +60,7 @@ class PrimaiteGATEClient(GATEClient): return self.parent_session.training_options.rl_algorithm @property - def seed(self) -> int | None: + def seed(self) -> Optional[int]: """The seed to use for the environment's random number generator.""" return self.parent_session.training_options.seed @@ -115,7 +115,7 @@ class PrimaiteGATEClient(GATEClient): info = {} return obs, rew, term, trunc, info - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict]: """Reset the environment. This method is called when the environment is initialized and at the end of each episode. diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index db323cf6..7f79ac2b 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -65,6 +65,10 @@ class Application(IOSoftware): self.sys_log.info(f"Running Application {self.name}") self.operating_state = ApplicationOperatingState.RUNNING + def _application_loop(self): + """THe main application loop.""" + pass + def close(self) -> None: """Close the Application.""" if self.operating_state == ApplicationOperatingState.RUNNING: diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index d021cb78..28e826fd 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -128,11 +128,11 @@ class DatabaseClient(Application): ) return self._query(sql=sql, query_id=query_id, is_reattempt=True) - def run(self) -> None: + def execute(self) -> None: """Run the DatabaseClient.""" - super().run() - self.operating_state = ApplicationOperatingState.RUNNING - self.connect() + super().execute() + if self.operating_state == ApplicationOperatingState.RUNNING: + self.connect() def query(self, sql: str) -> bool: """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index ea9c3ac3..6799358d 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -30,7 +30,7 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.run() + self.execute() def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 996e6790..aec7bbd8 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -1,27 +1,46 @@ +from enum import IntEnum from ipaddress import IPv4Address from typing import Optional +from primaite.game.science import simulate_trial +from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient +class DataManipulationAttackStage(IntEnum): + """ + Enumeration representing different stages of a data manipulation attack. + + This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the + simulation. Each stage represents a specific phase in the attack process. + """ + NOT_STARTED = 0 + "Indicates that the attack has not started yet." + LOGON = 1 + "The stage where logon procedures are simulated." + PORT_SCAN = 2 + "Represents the stage of performing a horizontal port scan on the target." + ATTACKING = 3 + "Stage of actively attacking the target." + COMPLETE = 4 + "Indicates the attack has been successfully completed." + FAILED = 5 + "Signifies that the attack has failed." + + class DataManipulationBot(DatabaseClient): - """ - Red Agent Data Integration Service. - - The Service represents a bot that causes files/folders in the File System to - become corrupted. - """ - + """A bot that simulates a script which performs a SQL injection attack.""" server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None server_password: Optional[str] = None + attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DataManipulationBot" def configure( - self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None + self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None ): """ Configure the DataManipulatorBot to communicate with a DatabaseService. @@ -37,15 +56,92 @@ class DataManipulationBot(DatabaseClient): f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}." ) - def run(self): - """Run the DataManipulationBot.""" - if self.server_ip_address and self.payload: - self.sys_log.info(f"{self.name}: Attempting to start the {self.name}") - super().run() - if not self.connected: - self.connect() - if self.connected: - self.query(self.payload) - self.sys_log.info(f"{self.name} payload delivered: {self.payload}") + def _logon(self): + """ + Simulate the logon process as the initial stage of the attack. + + Advances the attack stage to `LOGON` if successful. + """ + if self.attack_stage == DataManipulationAttackStage.NOT_STARTED: + # Bypass this stage as we're not dealing with logon for now + self.sys_log.info(f"{self.name}: ") + self.attack_stage = DataManipulationAttackStage.LOGON + + def _perform_port_scan(self, p_of_success: Optional[float] = 0.1): + """ + Perform a simulated port scan to check for open SQL ports. + + Advances the attack stage to `PORT_SCAN` if successful. + + :param p_of_success: Probability of successful port scan, by default 0.1. + """ + if self.attack_stage == DataManipulationAttackStage.LOGON: + # perform a port scan to identify that the SQL port is open on the server + if simulate_trial(p_of_success): + self.sys_log.info(f"{self.name}: Performing port scan") + # perform the port scan + port_is_open = True # Temporary; later we can implement NMAP port scan. + if port_is_open: + self.sys_log.info(f"{self.name}: ") + self.attack_stage = DataManipulationAttackStage.PORT_SCAN + + def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1): + """ + Execute the data manipulation attack on the target. + + Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful. + + :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. + """ + if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: + # perform the actual data manipulation attack + if simulate_trial(p_of_success): + + self.sys_log.info(f"{self.name}: Performing port scan") + # perform the attack + if not self.connected: + self.connect() + if self.connected: + self.query(self.payload) + self.sys_log.info(f"{self.name} payload delivered: {self.payload}") + attack_successful = True + if attack_successful: + self.sys_log.info(f"{self.name}: Performing port scan") + self.attack_stage = DataManipulationAttackStage.COMPLETE + else: + self.sys_log.info(f"{self.name}: Performing port scan") + self.attack_stage = DataManipulationAttackStage.FAILED + + def execute(self): + """ + Execute the Data Manipulation Bot + + Calls the parent classes execute method before starting the application loop. + """ + super().execute() + self._application_loop() + + def _application_loop(self): + """ + The main application loop of the bot, handling the attack process. + + This is the core loop where the bot sequentially goes through the stages of the attack. + """ + + if self.operating_state != ApplicationOperatingState.RUNNING: + return + if self.server_ip_address and self.payload and self.operating_state: + self.sys_log.info(f"{self.name}: Running") + self._logon() + self._perform_port_scan() + self._perform_data_manipulation() else: - self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_ip_address and payload.") + self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.") + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the bot, triggering the application loop. + + :param timestep: The timestep value to update the bot's state. + """ + self._application_loop() diff --git a/tests/conftest.py b/tests/conftest.py index dc749cfc..c046ca0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Any, Dict, Union from unittest.mock import patch -import nodeenv import pytest from primaite import getLogger diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index f4546cbf..e36cff2b 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -10,7 +10,7 @@ def test_web_page_home_page(uc2_network): """Test to see if the browser is able to open the main page of the web server.""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() + web_client.execute() assert web_client.operating_state == ApplicationOperatingState.RUNNING assert web_client.get_webpage("http://arcd.com/") is True @@ -24,7 +24,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network): """Test to see if the client can handle requests with domain names""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() + web_client.execute() assert web_client.operating_state == ApplicationOperatingState.RUNNING assert web_client.get_webpage("http://arcd.com/users/") is True @@ -38,7 +38,7 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network): """Test to see if the client can handle requests that use ip_address.""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() + web_client.execute() web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address From 1c5ff66d26599834f75c7cbc402fdf4f05a041c8 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Thu, 16 Nov 2023 13:26:30 +0000 Subject: [PATCH 02/35] Pass execution definition from config to agent --- .../config/_package_data/example_config.yaml | 4 ++++ src/primaite/game/agent/interface.py | 15 +++++++++++++-- src/primaite/game/session.py | 9 +++++++-- .../simulator/system/applications/web_browser.py | 3 +++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index ee42cf4f..f034f9ea 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -58,6 +58,10 @@ game_config: team: RED type: RedDatabaseCorruptingAgent + execution_definition: + port_scan_p_of_success: 0.1 + data_manipulation_p_of_success: 0.1 + observation_space: type: UC2RedObservation options: diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 78d18a68..d04b298e 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, TypeAlias, Union import numpy as np +from pydantic import BaseModel from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationSpace @@ -11,6 +12,11 @@ from primaite.game.agent.rewards import RewardFunction ObsType: TypeAlias = Union[Dict, np.ndarray] +class AgentExecutionDefinition(BaseModel): + port_scan_p_of_success: float = 0.1 + data_manipulation_p_of_success: float = 0.1 + + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -20,6 +26,7 @@ class AbstractAgent(ABC): action_space: Optional[ActionManager], observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], + execution_definition: Optional[AgentExecutionDefinition] ) -> None: """ Initialize an agent. @@ -40,7 +47,7 @@ class AbstractAgent(ABC): # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info # by for example specifying target ip addresses, or converting a node ID into a uuid - self.execution_definition = None + self.execution_definition = execution_definition or AgentExecutionDefinition() def convert_state_to_obs(self, state: Dict) -> ObsType: """ @@ -110,7 +117,11 @@ class RandomAgent(AbstractScriptedAgent): return self.action_space.get_action(self.action_space.space.sample()) class DataManipulationAgent(AbstractScriptedAgent): - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + return self.action_space.get_action(self.action_space.space.sample()) class AbstractGATEAgent(AbstractAgent): """Base class for actors controlled via external messages, such as RL policies.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 9c2bb6b7..082ed281 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, RandomAgent +from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -438,6 +438,8 @@ class PrimaiteSession: # CREATE REWARD FUNCTION rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) + execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {})) + # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": # TODO: implement non-random agents and fix this parsing @@ -446,6 +448,7 @@ class PrimaiteSession: action_space=action_space, observation_space=obs_space, reward_function=rew_function, + execution_definition=execution_definition, ) sess.agents.append(new_agent) elif agent_type == "GATERLAgent": @@ -454,15 +457,17 @@ class PrimaiteSession: action_space=action_space, observation_space=obs_space, reward_function=rew_function, + execution_definition=execution_definition, ) sess.agents.append(new_agent) sess.rl_agent = new_agent elif agent_type == "RedDatabaseCorruptingAgent": - new_agent = RandomAgent( + new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=rew_function, + execution_definition=execution_definition, ) sess.agents.append(new_agent) else: diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 6799358d..964e1ce4 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -135,3 +135,6 @@ class WebBrowser(Application): self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}") self.latest_response = payload return True + + def execute(self): + pass From 227e73602f8468523da60e8fe983622959d9ae92 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 17 Nov 2023 11:51:19 +0000 Subject: [PATCH 03/35] Pass execution definition from config to agent --- src/primaite/game/agent/interface.py | 37 ++++++++++++++++++- src/primaite/game/session.py | 2 +- .../red_services/data_manipulation_bot.py | 10 +++-- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index d04b298e..c591c554 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,6 +1,6 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union import numpy as np from pydantic import BaseModel @@ -8,13 +8,21 @@ from pydantic import BaseModel from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction +from primaite.simulator.network.hardware.base import Node + +if TYPE_CHECKING: + from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot ObsType: TypeAlias = Union[Dict, np.ndarray] class AgentExecutionDefinition(BaseModel): + """Additional configuration for agents.""" + port_scan_p_of_success: float = 0.1 + "The probability of a port scan succeeding." data_manipulation_p_of_success: float = 0.1 + "The probability of data manipulation succeeding." class AbstractAgent(ABC): @@ -26,7 +34,7 @@ class AbstractAgent(ABC): action_space: Optional[ActionManager], observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], - execution_definition: Optional[AgentExecutionDefinition] + execution_definition: Optional[AgentExecutionDefinition], ) -> None: """ Initialize an agent. @@ -116,13 +124,38 @@ class RandomAgent(AbstractScriptedAgent): """ return self.action_space.get_action(self.action_space.space.sample()) + class DataManipulationAgent(AbstractScriptedAgent): + """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # get node ids that are part of the agent's observation space + node_ids: List[str] = [n.where[-1] for n in self.observation_space.obs.nodes] + # get all nodes from their ids + nodes: List[Node] = [n for n_id, n in self.action_space.sim.network.nodes.items() if n_id in node_ids] + + # get execution definition for data manipulation bots + for node in nodes: + bot_sw: Optional["DataManipulationBot"] = node.software_manager.software.get("DataManipulationBot") + + if bot_sw is not None: + bot_sw.execution_definition = self.execution_definition + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + """Randomly sample an action from the action space. + + :param obs: _description_ + :type obs: ObsType + :param reward: _description_, defaults to None + :type reward: float, optional + :return: _description_ + :rtype: Tuple[str, Dict] + """ return self.action_space.get_action(self.action_space.space.sample()) + class AbstractGATEAgent(AbstractAgent): """Base class for actors controlled via external messages, such as RL policies.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 082ed281..5f3fb7b9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition +from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index aec7bbd8..35ea413a 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -2,6 +2,7 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Optional +from primaite.game.agent.interface import AgentExecutionDefinition from primaite.game.science import simulate_trial from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient @@ -14,6 +15,7 @@ class DataManipulationAttackStage(IntEnum): This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the simulation. Each stage represents a specific phase in the attack process. """ + NOT_STARTED = 0 "Indicates that the attack has not started yet." LOGON = 1 @@ -30,17 +32,19 @@ class DataManipulationAttackStage(IntEnum): class DataManipulationBot(DatabaseClient): """A bot that simulates a script which performs a SQL injection attack.""" + server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None server_password: Optional[str] = None attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED + execution_definition: AgentExecutionDefinition = AgentExecutionDefinition() def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DataManipulationBot" def configure( - self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None + self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None ): """ Configure the DataManipulatorBot to communicate with a DatabaseService. @@ -96,7 +100,6 @@ class DataManipulationBot(DatabaseClient): if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): - self.sys_log.info(f"{self.name}: Performing port scan") # perform the attack if not self.connected: @@ -114,7 +117,7 @@ class DataManipulationBot(DatabaseClient): def execute(self): """ - Execute the Data Manipulation Bot + Execute the Data Manipulation Bot. Calls the parent classes execute method before starting the application loop. """ @@ -127,7 +130,6 @@ class DataManipulationBot(DatabaseClient): This is the core loop where the bot sequentially goes through the stages of the attack. """ - if self.operating_state != ApplicationOperatingState.RUNNING: return if self.server_ip_address and self.payload and self.operating_state: From 7e0e8a476817118005307fa129f025a00cad0360 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Mon, 20 Nov 2023 10:38:01 +0000 Subject: [PATCH 04/35] Pass agent settings from config to agent --- .../config/_package_data/example_config.yaml | 14 +++++++------ src/primaite/game/agent/interface.py | 21 +++++++++++++++++++ src/primaite/game/session.py | 12 ++++++++++- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f034f9ea..700a45fd 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -50,9 +50,10 @@ game_config: - type: DUMMY agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + start_settings: + start_step: 5 + frequency: 4 + variance: 3 - ref: client_1_data_manipulation_red_bot team: RED @@ -106,9 +107,10 @@ game_config: - type: DUMMY agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_step: 25 - frequency: 20 - variance: 5 + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index c591c554..70eb1980 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -25,6 +25,24 @@ class AgentExecutionDefinition(BaseModel): "The probability of data manipulation succeeding." +class AgentStartSettings(BaseModel): + """Configuration values for when an agent starts performing actions.""" + + start_step: int = 5 + "The timestep at which an agent begins performing it's actions" + frequency: int = 5 + "The number of timesteps to wait between performing actions" + variance: int = 0 + "The amount the frequency can randomly change to" + + +class AgentSettings(BaseModel): + """Settings for configuring the operation of an agent.""" + + start_settings: Optional[AgentStartSettings] = None + "Configuration for when an agent begins performing it's actions" + + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -35,6 +53,7 @@ class AbstractAgent(ABC): observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], execution_definition: Optional[AgentExecutionDefinition], + agent_settings: Optional[AgentSettings], ) -> None: """ Initialize an agent. @@ -57,6 +76,8 @@ class AbstractAgent(ABC): # by for example specifying target ip addresses, or converting a node ID into a uuid self.execution_definition = execution_definition or AgentExecutionDefinition() + self.agent_settings = agent_settings or AgentSettings() + def convert_state_to_obs(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 5f3fb7b9..9701fec9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,7 +10,13 @@ from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent +from primaite.game.agent.interface import ( + AbstractAgent, + AgentExecutionDefinition, + AgentSettings, + DataManipulationAgent, + RandomAgent, +) from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -439,6 +445,7 @@ class PrimaiteSession: rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {})) + agent_settings = AgentSettings(**agent_cfg.get("agent_settings", {})) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": @@ -449,6 +456,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, execution_definition=execution_definition, + agent_settings=agent_settings, ) sess.agents.append(new_agent) elif agent_type == "GATERLAgent": @@ -458,6 +466,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, execution_definition=execution_definition, + agent_settings=agent_settings, ) sess.agents.append(new_agent) sess.rl_agent = new_agent @@ -468,6 +477,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, execution_definition=execution_definition, + agent_settings=agent_settings, ) sess.agents.append(new_agent) else: From 2975aa882774c3b5979072646de64c243ab880b4 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Tue, 21 Nov 2023 11:42:01 +0000 Subject: [PATCH 05/35] Execute data manipulation bots from agent --- src/primaite/game/agent/interface.py | 38 ++++++++++++++++++- src/primaite/game/session.py | 4 +- .../system/applications/database_client.py | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 70eb1980..94878947 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -24,6 +24,20 @@ class AgentExecutionDefinition(BaseModel): data_manipulation_p_of_success: float = 0.1 "The probability of data manipulation succeeding." + @classmethod + def from_config(cls, config: Optional[Dict]) -> "AgentExecutionDefinition": + """Construct an AgentExecutionDefinition from a config dictionary. + + :param config: A dict of options for the execution definition. + :type config: Dict + :return: The execution definition. + :rtype: AgentExecutionDefinition + """ + if config is None: + return cls() + + return cls(**config) + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -42,6 +56,20 @@ class AgentSettings(BaseModel): start_settings: Optional[AgentStartSettings] = None "Configuration for when an agent begins performing it's actions" + @classmethod + def from_config(cls, config: Optional[Dict]) -> "AgentSettings": + """Construct agent settings from a config dictionary. + + :param config: A dict of options for the agent settings. + :type config: Dict + :return: The agent settings. + :rtype: AgentSettings + """ + if config is None: + return cls() + + return cls(**config) + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -149,6 +177,8 @@ class RandomAgent(AbstractScriptedAgent): class DataManipulationAgent(AbstractScriptedAgent): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + data_manipulation_bots: List["DataManipulationBot"] = [] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -163,6 +193,7 @@ class DataManipulationAgent(AbstractScriptedAgent): if bot_sw is not None: bot_sw.execution_definition = self.execution_definition + self.data_manipulation_bots.append(bot_sw) def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. @@ -174,7 +205,12 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - return self.action_space.get_action(self.action_space.space.sample()) + # TODO: Move this to the appropriate place + # return self.action_space.get_action(self.action_space.space.sample()) + for bot in self.data_manipulation_bots: + bot.execute() + + return ("DONOTHING", {"dummy": 0}) class AbstractGATEAgent(AbstractAgent): diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 9701fec9..1b086c35 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -444,8 +444,8 @@ class PrimaiteSession: # CREATE REWARD FUNCTION rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) - execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {})) - agent_settings = AgentSettings(**agent_cfg.get("agent_settings", {})) + execution_definition = AgentExecutionDefinition.from_config(agent_cfg.get("execution_definition")) + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 28e826fd..e15249e3 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -130,7 +130,7 @@ class DatabaseClient(Application): def execute(self) -> None: """Run the DatabaseClient.""" - super().execute() + # super().execute() if self.operating_state == ApplicationOperatingState.RUNNING: self.connect() From d8154bbebd4e6d98aaf6cf51628be1b176ad00b8 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Tue, 21 Nov 2023 11:43:47 +0000 Subject: [PATCH 06/35] Add tests for data manipulation bot attack stages --- .../test_data_manipulation_bot.py | 61 +++++++++++++++++-- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index dd785cc1..5127254c 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -1,20 +1,73 @@ from ipaddress import IPv4Address +import pytest + from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot +from primaite.simulator.system.services.red_services.data_manipulation_bot import ( + DataManipulationAttackStage, + DataManipulationBot, +) -def test_creation(): +@pytest.fixture(scope="function") +def dm_client() -> Node: network = arcd_uc2_network() + return network.get_node_by_hostname("client_1") - client_1: Node = network.get_node_by_hostname("client_1") - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] +@pytest.fixture +def dm_bot(dm_client) -> DataManipulationBot: + return dm_client.software_manager.software["DataManipulationBot"] + + +def test_create_dm_bot(dm_client): + data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"] assert data_manipulation_bot.name == "DataManipulationBot" assert data_manipulation_bot.port == Port.POSTGRES_SERVER assert data_manipulation_bot.protocol == IPProtocol.TCP assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;" + + +def test_dm_bot_logon(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.NOT_STARTED + + dm_bot._logon() + + assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON + + +def test_dm_bot_perform_port_scan_no_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.LOGON + + dm_bot._perform_port_scan(p_of_success=0.0) + + assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON + + +def test_dm_bot_perform_port_scan_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.LOGON + + dm_bot._perform_port_scan(p_of_success=1.0) + + assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN + + +def test_dm_bot_perform_data_manipulation_no_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + + dm_bot._perform_data_manipulation(p_of_success=0.0) + + assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN + + +def test_dm_bot_perform_data_manipulation_success(dm_bot): + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + + dm_bot._perform_data_manipulation(p_of_success=1.0) + + assert dm_bot.attack_stage in (DataManipulationAttackStage.COMPLETE, DataManipulationAttackStage.FAILED) + assert dm_bot.connected From 48af0229637726c9fc953ecf54b2329947151a1a Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Tue, 21 Nov 2023 13:41:38 +0000 Subject: [PATCH 07/35] Run agent at configured timesteps --- src/primaite/game/agent/interface.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 94878947..d2479b38 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,4 +1,5 @@ """Interface for agents.""" +import random from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union @@ -178,10 +179,13 @@ class DataManipulationAgent(AbstractScriptedAgent): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" data_manipulation_bots: List["DataManipulationBot"] = [] + next_execution_timestep: int = 0 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.next_execution_timestep = self.agent_settings.start_settings.start_step + # get node ids that are part of the agent's observation space node_ids: List[str] = [n.where[-1] for n in self.observation_space.obs.nodes] # get all nodes from their ids @@ -207,10 +211,19 @@ class DataManipulationAgent(AbstractScriptedAgent): """ # TODO: Move this to the appropriate place # return self.action_space.get_action(self.action_space.space.sample()) + + timestep = self.action_space.session.step_counter + + if timestep < self.next_execution_timestep: + return "DONOTHING", {"dummy": 0} + + var = random.randint(-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance) + self.next_execution_timestep = timestep + self.agent_settings.start_settings.frequency + var + for bot in self.data_manipulation_bots: bot.execute() - return ("DONOTHING", {"dummy": 0}) + return "DONOTHING", {"dummy": 0} class AbstractGATEAgent(AbstractAgent): From aa65c53a95ad33b356a588ed054b9e0a0dfaf3cc Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Tue, 21 Nov 2023 15:09:51 +0000 Subject: [PATCH 08/35] Pass probability of success through to functions --- .../system/services/red_services/data_manipulation_bot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 35ea413a..5e4e2d3f 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -135,8 +135,8 @@ class DataManipulationBot(DatabaseClient): if self.server_ip_address and self.payload and self.operating_state: self.sys_log.info(f"{self.name}: Running") self._logon() - self._perform_port_scan() - self._perform_data_manipulation() + self._perform_port_scan(p_of_success=self.execution_definition.port_scan_p_of_success) + self._perform_data_manipulation(p_of_success=self.execution_definition.data_manipulation_p_of_success) else: self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.") From 061e5081871a7f9143769b545ca6de8a44f8c158 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Wed, 22 Nov 2023 16:24:17 +0000 Subject: [PATCH 09/35] Add repeat option to data manipulation bot --- .../red_services/data_manipulation_bot.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 5e4e2d3f..eae3f0e3 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -38,13 +38,19 @@ class DataManipulationBot(DatabaseClient): server_password: Optional[str] = None attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED execution_definition: AgentExecutionDefinition = AgentExecutionDefinition() + repeat: bool = False + "Whether to repeat attacking once finished." def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DataManipulationBot" def configure( - self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None + self, + server_ip_address: IPv4Address, + server_password: Optional[str] = None, + payload: Optional[str] = None, + repeat: bool = False, ): """ Configure the DataManipulatorBot to communicate with a DatabaseService. @@ -52,12 +58,15 @@ class DataManipulationBot(DatabaseClient): :param server_ip_address: The IP address of the Node the DatabaseService is on. :param server_password: The password on the DatabaseService. :param payload: The data manipulation query payload. + :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address self.payload = payload self.server_password = server_password + self.repeat = repeat self.sys_log.info( - f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}." + f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, " + f"{repeat=}." ) def _logon(self): @@ -100,7 +109,7 @@ class DataManipulationBot(DatabaseClient): if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): - self.sys_log.info(f"{self.name}: Performing port scan") + self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack if not self.connected: self.connect() @@ -109,10 +118,10 @@ class DataManipulationBot(DatabaseClient): self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True if attack_successful: - self.sys_log.info(f"{self.name}: Performing port scan") + self.sys_log.info(f"{self.name}: Data manipulation successful") self.attack_stage = DataManipulationAttackStage.COMPLETE else: - self.sys_log.info(f"{self.name}: Performing port scan") + self.sys_log.info(f"{self.name}: Data manipulation failed") self.attack_stage = DataManipulationAttackStage.FAILED def execute(self): @@ -137,6 +146,12 @@ class DataManipulationBot(DatabaseClient): self._logon() self._perform_port_scan(p_of_success=self.execution_definition.port_scan_p_of_success) self._perform_data_manipulation(p_of_success=self.execution_definition.data_manipulation_p_of_success) + + if self.repeat and self.attack_stage in ( + DataManipulationAttackStage.COMPLETE, + DataManipulationAttackStage.FAILED, + ): + self.attack_stage = DataManipulationAttackStage.NOT_STARTED else: self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.") From c93705867f28d2b68e3e98888a5de1cb424bc890 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Thu, 23 Nov 2023 15:53:47 +0000 Subject: [PATCH 10/35] Move configuration from agent to data manipulation bot --- .../config/_package_data/example_config.yaml | 14 +++--- src/primaite/game/agent/interface.py | 43 ------------------- src/primaite/game/session.py | 22 +++++----- .../red_services/data_manipulation_bot.py | 11 ++++- 4 files changed, 25 insertions(+), 65 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 700a45fd..274da7aa 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -59,10 +59,6 @@ game_config: team: RED type: RedDatabaseCorruptingAgent - execution_definition: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 - observation_space: type: UC2RedObservation options: @@ -83,11 +79,6 @@ game_config: - type: DONOTHING # "AgentExecutionDefinition": - """Construct an AgentExecutionDefinition from a config dictionary. - - :param config: A dict of options for the execution definition. - :type config: Dict - :return: The execution definition. - :rtype: AgentExecutionDefinition - """ - if config is None: - return cls() - - return cls(**config) - - class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -81,7 +57,6 @@ class AbstractAgent(ABC): action_space: Optional[ActionManager], observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], - execution_definition: Optional[AgentExecutionDefinition], agent_settings: Optional[AgentSettings], ) -> None: """ @@ -100,11 +75,6 @@ class AbstractAgent(ABC): self.action_space: Optional[ActionManager] = action_space self.observation_space: Optional[ObservationSpace] = observation_space self.reward_function: Optional[RewardFunction] = reward_function - - # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info - # by for example specifying target ip addresses, or converting a node ID into a uuid - self.execution_definition = execution_definition or AgentExecutionDefinition() - self.agent_settings = agent_settings or AgentSettings() def convert_state_to_obs(self, state: Dict) -> ObsType: @@ -186,19 +156,6 @@ class DataManipulationAgent(AbstractScriptedAgent): self.next_execution_timestep = self.agent_settings.start_settings.start_step - # get node ids that are part of the agent's observation space - node_ids: List[str] = [n.where[-1] for n in self.observation_space.obs.nodes] - # get all nodes from their ids - nodes: List[Node] = [n for n_id, n in self.action_space.sim.network.nodes.items() if n_id in node_ids] - - # get execution definition for data manipulation bots - for node in nodes: - bot_sw: Optional["DataManipulationBot"] = node.software_manager.software.get("DataManipulationBot") - - if bot_sw is not None: - bot_sw.execution_definition = self.execution_definition - self.data_manipulation_bots.append(bot_sw) - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 1b086c35..f675e33c 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,13 +10,7 @@ from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import ( - AbstractAgent, - AgentExecutionDefinition, - AgentSettings, - DataManipulationAgent, - RandomAgent, -) +from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, RandomAgent from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -366,6 +360,16 @@ class PrimaiteSession: if "domain_mapping" in opt: for domain, ip in opt["domain_mapping"].items(): new_service.dns_register(domain, ip) + if service_type == "DataManipulationBot": + if "options" in service_cfg: + opt = service_cfg["options"] + new_service.configure( + server_ip_address=opt.get("server_ip"), + payload=opt.get("payload"), + port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), + data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), + ) + if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: application_ref = application_cfg["ref"] @@ -444,7 +448,6 @@ class PrimaiteSession: # CREATE REWARD FUNCTION rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) - execution_definition = AgentExecutionDefinition.from_config(agent_cfg.get("execution_definition")) agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) # CREATE AGENT @@ -455,7 +458,6 @@ class PrimaiteSession: action_space=action_space, observation_space=obs_space, reward_function=rew_function, - execution_definition=execution_definition, agent_settings=agent_settings, ) sess.agents.append(new_agent) @@ -465,7 +467,6 @@ class PrimaiteSession: action_space=action_space, observation_space=obs_space, reward_function=rew_function, - execution_definition=execution_definition, agent_settings=agent_settings, ) sess.agents.append(new_agent) @@ -476,7 +477,6 @@ class PrimaiteSession: action_space=action_space, observation_space=obs_space, reward_function=rew_function, - execution_definition=execution_definition, agent_settings=agent_settings, ) sess.agents.append(new_agent) diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index eae3f0e3..e3f5b95d 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -2,7 +2,6 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Optional -from primaite.game.agent.interface import AgentExecutionDefinition from primaite.game.science import simulate_trial from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient @@ -36,8 +35,10 @@ class DataManipulationBot(DatabaseClient): server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None server_password: Optional[str] = None + port_scan_p_of_success: float = 0.1 + data_manipulation_p_of_success: float = 0.1 + attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED - execution_definition: AgentExecutionDefinition = AgentExecutionDefinition() repeat: bool = False "Whether to repeat attacking once finished." @@ -50,6 +51,8 @@ class DataManipulationBot(DatabaseClient): server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None, + port_scan_p_of_success: float = 0.1, + data_manipulation_p_of_success: float = 0.1, repeat: bool = False, ): """ @@ -58,11 +61,15 @@ class DataManipulationBot(DatabaseClient): :param server_ip_address: The IP address of the Node the DatabaseService is on. :param server_password: The password on the DatabaseService. :param payload: The data manipulation query payload. + :param port_scan_p_of_success: The probability of success for the port scan stage. + :param data_manipulation_p_of_success: The probability of success for the data manipulation stage. :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address self.payload = payload self.server_password = server_password + self.port_scan_p_of_success = port_scan_p_of_success + self.data_manipulation_p_of_success = data_manipulation_p_of_success self.repeat = repeat self.sys_log.info( f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, " From 5f1a5af1b45eccf5154f35e0143282dc3491e089 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Thu, 23 Nov 2023 16:06:19 +0000 Subject: [PATCH 11/35] Add data manipulation bot action manager --- .../config/_package_data/example_config.yaml | 8 +-- src/primaite/game/agent/actions.py | 49 +++++++++++++++++++ src/primaite/game/agent/interface.py | 27 +++++----- .../red_services/data_manipulation_bot.py | 8 +++ .../test_data_manipulation_bot.py | 2 - 5 files changed, 76 insertions(+), 18 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 274da7aa..aff54d62 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -78,12 +78,12 @@ game_config: action_list: - type: DONOTHING # None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} + self.verb: str + + def form_request(self, node_id: int, application_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id) + if node_uuid is None or application_uuid is None: + return ["do_nothing"] + return ["network", "node", node_uuid, "application", application_uuid, self.verb] + + +class NodeApplicationExecuteAction(NodeApplicationAbstractAction): + """Action which executes an application.""" + + def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) + self.verb = "execute" + + class NodeFolderAbstractAction(AbstractAction): """ Base class for folder actions. @@ -536,6 +567,7 @@ class ActionManager: "NODE_SERVICE_RESTART": NodeServiceRestartAction, "NODE_SERVICE_DISABLE": NodeServiceDisableAction, "NODE_SERVICE_ENABLE": NodeServiceEnableAction, + "NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction, "NODE_FILE_SCAN": NodeFileScanAction, "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, "NODE_FILE_DELETE": NodeFileDeleteAction, @@ -565,6 +597,7 @@ class ActionManager: max_folders_per_node: int = 2, # allows calculating shape max_files_per_folder: int = 2, # allows calculating shape max_services_per_node: int = 2, # allows calculating shape + max_applications_per_node: int = 10, # allows calculating shape max_nics_per_node: int = 8, # allows calculating shape max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol @@ -622,6 +655,7 @@ class ActionManager: "num_folders": max_folders_per_node, "num_files": max_files_per_folder, "num_services": max_services_per_node, + "num_applications": max_applications_per_node, "num_nics": max_nics_per_node, "num_acl_rules": max_acl_rules, "num_protocols": len(self.protocols), @@ -775,6 +809,21 @@ class ActionManager: service_uuids = list(node.services.keys()) return service_uuids[service_idx] if len(service_uuids) > service_idx else None + def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: + """Get the application UUID corresponding to the given node and service indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param application_idx: The index of the service on the node. + :type application_idx: int + :return: The UUID of the service. Or None if the node has fewer services than the given index. + :rtype: Optional[str] + """ + node_uuid = self.get_node_uuid_by_idx(node_idx) + node = self.sim.network.nodes[node_uuid] + application_uuids = list(node.applications.keys()) + return application_uuids[application_idx] if len(application_uuids) > application_idx else None + def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: """Get the internet protocol corresponding to the given index. diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 5e73a423..33932df2 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -154,7 +154,17 @@ class DataManipulationAgent(AbstractScriptedAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.next_execution_timestep = self.agent_settings.start_settings.start_step + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) + + def _set_next_execution_timestep(self, timestep: int) -> None: + """Set the next execution timestep with a configured random variance. + + :param timestep: The timestep to add variance to. + """ + random_timestep_increment = random.randint( + -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + ) + self.next_execution_timestep = timestep + random_timestep_increment def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. @@ -166,21 +176,14 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - # TODO: Move this to the appropriate place - # return self.action_space.get_action(self.action_space.space.sample()) + current_timestep = self.action_space.session.step_counter - timestep = self.action_space.session.step_counter - - if timestep < self.next_execution_timestep: + if current_timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} - var = random.randint(-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance) - self.next_execution_timestep = timestep + self.agent_settings.start_settings.frequency + var + self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) - for bot in self.data_manipulation_bots: - bot.execute() - - return "DONOTHING", {"dummy": 0} + return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} class AbstractGATEAgent(AbstractAgent): diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index e3f5b95d..f4b31cb1 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Optional from primaite.game.science import simulate_trial +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient @@ -46,6 +47,13 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + + rm.add_request(name="execute", request_type=RequestType(func=self.execute)) + + return rm + def configure( self, server_ip_address: IPv4Address, diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index 5127254c..04e23e84 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -1,5 +1,3 @@ -from ipaddress import IPv4Address - import pytest from primaite.simulator.network.hardware.base import Node From b13a9d3daf34f38992b19f7854cbbf0eeb3e2723 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 09:25:55 +0000 Subject: [PATCH 12/35] Add application execution action for data manipulation bot --- .../config/_package_data/example_config.yaml | 9 ++++++--- src/primaite/game/agent/actions.py | 7 +++---- src/primaite/game/session.py | 14 ++++++++++++++ .../system/applications/database_client.py | 2 +- .../services/red_services/data_manipulation_bot.py | 2 +- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index aff54d62..8ea1c83c 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -67,8 +67,8 @@ game_config: observations: - logon_status - operating_status - services: - - service_ref: data_manipulation_bot + applications: + - application_ref: data_manipulation_bot observations: operating_status health_status @@ -89,6 +89,8 @@ game_config: options: nodes: - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -650,7 +652,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 dns_server: 192.168.1.10 - services: + applications: - ref: data_manipulation_bot type: DataManipulationBot options: @@ -658,6 +660,7 @@ simulation: data_manipulation_p_of_success: 0.1 payload: "DROP TABLE IF EXISTS user;" server_ip: 192.168.1.14 + services: - ref: client_1_dns_client type: DNSClient diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 0c78dac7..64d89722 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -594,6 +594,7 @@ class ActionManager: session: "PrimaiteSession", # reference to session for looking up stuff actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node + application_uuids: List[List[str]], # allows mapping index to application max_folders_per_node: int = 2, # allows calculating shape max_files_per_folder: int = 2, # allows calculating shape max_services_per_node: int = 2, # allows calculating shape @@ -635,6 +636,7 @@ class ActionManager: self.session: "PrimaiteSession" = session self.sim: Simulation = self.session.simulation self.node_uuids: List[str] = node_uuids + self.application_uuids: List[List[str]] = application_uuids self.protocols: List[str] = protocols self.ports: List[str] = ports @@ -819,10 +821,7 @@ class ActionManager: :return: The UUID of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ - node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] - application_uuids = list(node.applications.keys()) - return application_uuids[application_idx] if len(application_uuids) > application_idx else None + return self.application_uuids[node_idx][application_idx] def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: """Get the internet protocol corresponding to the given index. diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index f675e33c..cc4036ef 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -426,11 +426,25 @@ class PrimaiteSession: # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] + action_space_cfg["options"]["application_uuids"] = [] + # if a list of nodes is defined, convert them from node references to node UUIDs for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}): if "node_ref" in action_node_option: node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]] action_space_cfg["options"]["node_uuids"].append(node_uuid) + + if "applications" in action_node_option: + node_application_uuids = [] + for application_option in action_node_option["applications"]: + # TODO: remove inconsistency with the above nodes + application_uuid = sess.ref_map_applications[application_option["application_ref"]].uuid + node_application_uuids.append(application_uuid) + + action_space_cfg["options"]["application_uuids"].append(node_application_uuids) + else: + action_space_cfg["options"]["application_uuids"].append([]) + # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. # However, it's not possible to specify the node uuids directly in the config, as they are generated diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index e15249e3..9d85221e 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -141,7 +141,7 @@ class DatabaseClient(Application): :param sql: The SQL query. :return: True if the query was successful, otherwise False. """ - if self.connected and self.operating_state.RUNNING: + if self.connected and self.operating_state == ApplicationOperatingState.RUNNING: query_id = str(uuid4()) # Initialise the tracker of this ID to False diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index f4b31cb1..0ec64950 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -50,7 +50,7 @@ class DataManipulationBot(DatabaseClient): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=self.execute)) + rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.execute())) return rm From 92dabe59f7d31a270d7e4b937e3075eeb114f913 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 10:04:19 +0000 Subject: [PATCH 13/35] Fix data manipulation bot configuration --- src/primaite/game/session.py | 26 +++++++++++-------- .../red_services/data_manipulation_bot.py | 4 +-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index cc4036ef..286de498 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -331,6 +331,7 @@ class PrimaiteSession: print("invalid node type") if "services" in node_cfg: for service_cfg in node_cfg["services"]: + new_service = None service_ref = service_cfg["ref"] service_type = service_cfg["type"] service_types_mapping = { @@ -339,7 +340,6 @@ class PrimaiteSession: "DatabaseClient": DatabaseClient, "DatabaseService": DatabaseService, "WebServer": WebServer, - "DataManipulationBot": DataManipulationBot, } if service_type in service_types_mapping: print(f"installing {service_type} on node {new_node.hostname}") @@ -360,22 +360,15 @@ class PrimaiteSession: if "domain_mapping" in opt: for domain, ip in opt["domain_mapping"].items(): new_service.dns_register(domain, ip) - if service_type == "DataManipulationBot": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.configure( - server_ip_address=opt.get("server_ip"), - payload=opt.get("payload"), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), - ) if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: + new_application = None application_ref = application_cfg["ref"] application_type = application_cfg["type"] application_types_mapping = { "WebBrowser": WebBrowser, + "DataManipulationBot": DataManipulationBot, } if application_type in application_types_mapping: new_node.software_manager.install(application_types_mapping[application_type]) @@ -383,6 +376,16 @@ class PrimaiteSession: sess.ref_map_applications[application_ref] = new_application else: print(f"application type not found {application_type}") + + if application_type == "DataManipulationBot": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.configure( + server_ip_address=opt.get("server_ip"), + payload=opt.get("payload"), + port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), + data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), + ) if "nics" in node_cfg: for nic_num, nic_cfg in node_cfg["nics"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) @@ -437,7 +440,8 @@ class PrimaiteSession: if "applications" in action_node_option: node_application_uuids = [] for application_option in action_node_option["applications"]: - # TODO: remove inconsistency with the above nodes + # TODO: fix inconsistency with node uuids and application uuids. The node object get added to + # node_uuid, whereas here the application gets added by uuid. application_uuid = sess.ref_map_applications[application_option["application_ref"]].uuid node_application_uuids.append(application_uuid) diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 0ec64950..2b0bed30 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -159,8 +159,8 @@ class DataManipulationBot(DatabaseClient): if self.server_ip_address and self.payload and self.operating_state: self.sys_log.info(f"{self.name}: Running") self._logon() - self._perform_port_scan(p_of_success=self.execution_definition.port_scan_p_of_success) - self._perform_data_manipulation(p_of_success=self.execution_definition.data_manipulation_p_of_success) + self._perform_port_scan(p_of_success=self.port_scan_p_of_success) + self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success) if self.repeat and self.attack_stage in ( DataManipulationAttackStage.COMPLETE, From 178d911be005fc7f888d1aa1e679d6268a66cda3 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 10:05:36 +0000 Subject: [PATCH 14/35] Update data manipulation bot --- .../system/data_manipulation_bot.rst | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst index c9f8977a..e93c4e54 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -8,6 +8,8 @@ DataManipulationBot The ``DataManipulationBot`` class provides functionality to connect to a ``DatabaseService`` and execute malicious SQL statements. +The bot is controlled by a ``DataManipulationAgent``. + Overview -------- @@ -16,15 +18,25 @@ The bot is intended to simulate a malicious actor carrying out attacks like: - Dropping tables - Deleting records - Modifying data + On a database server by abusing an application's trusted database connectivity. +The bot performs attacks in the following stages to simulate the real pattern of an attack: + +- Logon - *The bot gains access to the node.* +- Port Scan - *The bot finds accessible database servers on the network.* +- Attacking - *The bot delivers the payload to the discovered database servers.* + +Each of these stages has a random, configurable probability of succeeding. The bot can also be configured to repeat the attack once complete. + Usage ----- - Create an instance and call ``configure`` to set: - - Target database server IP - - Database password (if needed) - - SQL statement payload + - Target database server IP + - Database password (if needed) + - SQL statement payload + - Probabilities for succeeding each of the above attack stages - Call ``run`` to connect and execute the statement. The bot handles connecting, executing the statement, and disconnecting. @@ -52,7 +64,7 @@ Implementation The bot extends ``DatabaseClient`` and leverages its connectivity. - Uses the Application base class for lifecycle management. -- Credentials and target IP set via ``configure``. +- Credentials, target IP and other options set via ``configure``. - ``run`` handles connecting, executing statement, and disconnecting. - SQL payload executed via ``query`` method. - Results in malicious SQL being executed on remote database server. From ff8b773c102243549d66eeaa357fa56df9be4094 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 11:10:34 +0000 Subject: [PATCH 15/35] Database Manipulation Bot bug fixes --- .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/agent/interface.py | 4 +- src/primaite/simulator/network/networks.py | 7 ++- .../system/applications/database_client.py | 4 +- .../red_services/data_manipulation_bot.py | 8 +-- .../assets/configs/bad_primaite_session.yaml | 51 ++++++++++++------- .../configs/eval_only_primaite_session.yaml | 45 +++++++++------- .../assets/configs/test_primaite_session.yaml | 41 ++++++++------- .../configs/train_only_primaite_session.yaml | 45 +++++++++------- 9 files changed, 124 insertions(+), 83 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 270760f5..af872a01 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -665,7 +665,7 @@ simulation: options: port_scan_p_of_success: 0.1 data_manipulation_p_of_success: 0.1 - payload: "DROP TABLE IF EXISTS user;" + payload: "DELETE" server_ip: 192.168.1.14 services: - ref: client_1_dns_client diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index ff0986a8..38116987 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -58,7 +58,7 @@ class AbstractAgent(ABC): action_space: Optional[ActionManager], observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], - agent_settings: Optional[AgentSettings], + agent_settings: Optional[AgentSettings] = None, ) -> None: """ Initialize an agent. @@ -217,7 +217,7 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - current_timestep = self.action_space.session.step_counter + current_timestep = self.action_manager.session.step_counter if current_timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index c0f9a07e..ea767b54 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -140,7 +140,12 @@ def arcd_uc2_network() -> Network: network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] - db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") + db_manipulation_bot.configure( + server_ip_address=IPv4Address("192.168.1.14"), + payload="DELETE", + port_scan_p_of_success=1.0, + data_manipulation_p_of_success=1.0, + ) # Client 2 client_2 = Computer( diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index a5c213cd..da2299c4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -129,9 +129,9 @@ class DatabaseClient(Application): ) return self._query(sql=sql, query_id=query_id, is_reattempt=True) - def execute(self) -> None: + def run(self) -> None: """Run the DatabaseClient.""" - super().execute() + super().run() if self.operating_state == ApplicationOperatingState.RUNNING: self.connect() diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 2b0bed30..17b89386 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -50,7 +50,7 @@ class DataManipulationBot(DatabaseClient): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.execute())) + rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) return rm @@ -139,13 +139,13 @@ class DataManipulationBot(DatabaseClient): self.sys_log.info(f"{self.name}: Data manipulation failed") self.attack_stage = DataManipulationAttackStage.FAILED - def execute(self): + def run(self): """ - Execute the Data Manipulation Bot. + Run the Data Manipulation Bot. Calls the parent classes execute method before starting the application loop. """ - super().execute() + super().run() self._application_loop() def _application_loop(self): diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 80567aea..6344eac0 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -2,9 +2,17 @@ training_config: rl_framework: SB3 rl_algorithm: PPO se3ed: 333 # Purposeful typo to check that error is raised with bad configuration. - n_learn_steps: 2560 + n_learn_episodes: 25 n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender +io_settings: + save_checkpoints: true + checkpoint_interval: 5 game_config: @@ -49,9 +57,10 @@ game_config: - type: DUMMY agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + start_settings: + start_step: 5 + frequency: 4 + variance: 3 - ref: client_1_data_manipulation_red_bot team: RED @@ -65,8 +74,8 @@ game_config: observations: - logon_status - operating_status - services: - - service_ref: data_manipulation_bot + applications: + - application_ref: data_manipulation_bot observations: operating_status health_status @@ -76,22 +85,19 @@ game_config: action_list: - type: DONOTHING # Date: Fri, 24 Nov 2023 11:52:33 +0000 Subject: [PATCH 16/35] #1859 - DB query now returns false if the query isn't ran due to the node being off --- src/primaite/game/agent/interface.py | 5 +---- .../simulator/system/applications/database_client.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 38116987..b321b17c 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,9 +1,8 @@ """Interface for agents.""" import random from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING -import numpy as np from gymnasium.core import ActType, ObsType from pydantic import BaseModel @@ -14,8 +13,6 @@ from primaite.game.agent.rewards import RewardFunction if TYPE_CHECKING: from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot -ObsType: TypeAlias = Union[Dict, np.ndarray] - class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index da2299c4..3c4f1b75 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -148,6 +148,7 @@ class DatabaseClient(Application): # Initialise the tracker of this ID to False self._query_success_tracker[query_id] = False return self._query(sql=sql, query_id=query_id) + return False def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ From e609f8eb50e935515a0d63ad85e9321404f8fd98 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 14:56:17 +0000 Subject: [PATCH 17/35] Fix misconfiguration in uc2 config and session --- .../config/_package_data/example_config.yaml | 18 +++++++++-- src/primaite/game/session.py | 31 ++++++++++++++++--- .../assets/configs/bad_primaite_session.yaml | 18 +++++++++-- .../configs/eval_only_primaite_session.yaml | 18 +++++++++-- .../assets/configs/test_primaite_session.yaml | 18 +++++++++-- .../configs/train_only_primaite_session.yaml | 18 +++++++++-- 6 files changed, 102 insertions(+), 19 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index af872a01..6455272c 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -560,7 +560,7 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: 0: @@ -571,6 +571,14 @@ simulation: action: PERMIT src_port: DNS dst_port: DNS + 2: + action: PERMIT + src_port: FTP + dst_port: FTP + 3: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -607,7 +615,7 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - ref: web_server_database_client @@ -628,6 +636,10 @@ simulation: services: - ref: database_service type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient - ref: backup_server type: server @@ -638,7 +650,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 7856cc9f..f0dcdd61 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -16,7 +16,7 @@ from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.io import SessionIO, SessionIOSettings from primaite.game.policy.policy import PolicyABC -from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.server import Server @@ -30,6 +30,8 @@ from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer @@ -334,6 +336,7 @@ class PrimaiteSession: subnet_mask=node_cfg["subnet_mask"], default_gateway=node_cfg["default_gateway"], dns_server=node_cfg["dns_server"], + operating_state=NodeOperatingState.ON, ) elif n_type == "server": new_node = Server( @@ -342,16 +345,26 @@ class PrimaiteSession: subnet_mask=node_cfg["subnet_mask"], default_gateway=node_cfg["default_gateway"], dns_server=node_cfg.get("dns_server"), + operating_state=NodeOperatingState.ON, ) elif n_type == "switch": - new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports")) + new_node = Switch( + hostname=node_cfg["hostname"], + num_ports=node_cfg.get("num_ports"), + operating_state=NodeOperatingState.ON, + ) elif n_type == "router": - new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports")) + new_node = Router( + hostname=node_cfg["hostname"], + num_ports=node_cfg.get("num_ports"), + operating_state=NodeOperatingState.ON, + ) if "ports" in node_cfg: for port_num, port_cfg in node_cfg["ports"].items(): new_node.configure_port( port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"] ) + # new_node.enable_port(port_num) if "acl" in node_cfg: for r_num, r_cfg in node_cfg["acl"].items(): # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating @@ -379,6 +392,8 @@ class PrimaiteSession: "DatabaseClient": DatabaseClient, "DatabaseService": DatabaseService, "WebServer": WebServer, + "FTPClient": FTPClient, + "FTPServer": FTPServer, } if service_type in service_types_mapping: print(f"installing {service_type} on node {new_node.hostname}") @@ -399,6 +414,12 @@ class PrimaiteSession: if "domain_mapping" in opt: for domain, ip in opt["domain_mapping"].items(): new_service.dns_register(domain, ip) + if service_type == "DatabaseService": + if "options" in service_cfg: + opt = service_cfg["options"] + if "backup_server_ip" in opt: + new_service.configure_backup(backup_server=IPv4Address(opt["backup_server_ip"])) + new_service.start() if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: @@ -435,7 +456,7 @@ class PrimaiteSession: node_ref ] = ( new_node.uuid - ) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object + ) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object # 2. create links between nodes for link_cfg in links_cfg: @@ -451,6 +472,8 @@ class PrimaiteSession: endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]] new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b) sess.ref_map_links[link_cfg["ref"]] = new_link.uuid + # endpoint_a.enable() + # endpoint_b.enable() # 3. create agents game_cfg = cfg["game_config"] diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 6344eac0..4d8e4669 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -560,7 +560,7 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: 0: @@ -571,6 +571,14 @@ simulation: action: PERMIT src_port: DNS dst_port: DNS + 2: + action: PERMIT + src_port: FTP + dst_port: FTP + 3: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -607,7 +615,7 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - ref: web_server_database_client @@ -628,6 +636,10 @@ simulation: services: - ref: database_service type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient - ref: backup_server type: server @@ -638,7 +650,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index aa8c8b1f..27a18d9f 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -560,7 +560,7 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: 0: @@ -571,6 +571,14 @@ simulation: action: PERMIT src_port: DNS dst_port: DNS + 2: + action: PERMIT + src_port: FTP + dst_port: FTP + 3: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -607,7 +615,7 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - ref: web_server_database_client @@ -628,6 +636,10 @@ simulation: services: - ref: database_service type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient - ref: backup_server type: server @@ -638,7 +650,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 8133c5d9..64be5488 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -560,7 +560,7 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: 0: @@ -571,6 +571,14 @@ simulation: action: PERMIT src_port: DNS dst_port: DNS + 2: + action: PERMIT + src_port: FTP + dst_port: FTP + 3: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -607,7 +615,7 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - ref: web_server_database_client @@ -628,6 +636,10 @@ simulation: services: - ref: database_service type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient - ref: backup_server type: server @@ -638,7 +650,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index f1e317d3..4cfe4df4 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -560,7 +560,7 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: 0: @@ -571,6 +571,14 @@ simulation: action: PERMIT src_port: DNS dst_port: DNS + 2: + action: PERMIT + src_port: FTP + dst_port: FTP + 3: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -607,7 +615,7 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - ref: web_server_database_client @@ -628,6 +636,10 @@ simulation: services: - ref: database_service type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient - ref: backup_server type: server @@ -638,7 +650,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server From e6f75f8b320f188475782b5564cc3f0bcc3413fe Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 15:15:24 +0000 Subject: [PATCH 18/35] Improve data manipulation bot documentation --- .../system/data_manipulation_bot.rst | 76 ++++++++++++++++++- 1 file changed, 72 insertions(+), 4 deletions(-) diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst index 03f2208b..eeae0b0a 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -8,8 +8,6 @@ DataManipulationBot The ``DataManipulationBot`` class provides functionality to connect to a ``DatabaseService`` and execute malicious SQL statements. -The bot is controlled by a ``DataManipulationAgent``. - Overview -------- @@ -23,11 +21,11 @@ On a database server by abusing an application's trusted database connectivity. The bot performs attacks in the following stages to simulate the real pattern of an attack: -- Logon - *The bot gains access to the node.* +- Logon - *The bot gains credentials and accesses the node.* - Port Scan - *The bot finds accessible database servers on the network.* - Attacking - *The bot delivers the payload to the discovered database servers.* -Each of these stages has a random, configurable probability of succeeding. The bot can also be configured to repeat the attack once complete. +Each of these stages has a random, configurable probability of succeeding (by default 10%). The bot can also be configured to repeat the attack once complete. Usage ----- @@ -41,6 +39,8 @@ Usage The bot handles connecting, executing the statement, and disconnecting. +In a simulation, the bot can be controlled by using ``DataManipulationAgent`` which calls ``run`` on the bot at configured timesteps. + Example ------- @@ -58,6 +58,74 @@ Example This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table. +Example with ``DataManipulationAgent`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If not using the data manipulation bot manually, it needs to be used with a data manipulation agent. Below is an example section of configuration file for setting up a simulation with data manipulation bot and agent. + +.. code-block:: yaml + + game_config: + # ... + agents: + - ref: data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + applications: + - application_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + # ... + + simulation: + network: + nodes: + - ref: client_1 + type: computer + # ... additional configuration here + applications: + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.1 + data_manipulation_p_of_success: 0.1 + payload: "DELETE" + server_ip: 192.168.1.14 + Implementation -------------- From c5cfbb825a275398d56799253b11cc3656d20777 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 15:15:45 +0000 Subject: [PATCH 19/35] Fix database client connect method --- src/primaite/simulator/system/applications/database_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 3c4f1b75..b24b6062 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -54,7 +54,7 @@ class DatabaseClient(Application): def connect(self) -> bool: """Connect to a Database Service.""" - if not self.connected and self.operating_state.RUNNING: + if not self.connected and self.operating_state == ApplicationOperatingState.RUNNING: return self._connect(self.server_ip_address, self.server_password) return False From e62ca22cb7d45fe7fa8b03d582b3c3f8fc66f676 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 15:53:07 +0000 Subject: [PATCH 20/35] Fix data manipulation bot tests --- .../red_services/data_manipulation_bot.py | 20 +++++++++---------- .../test_data_manipulation_bot.py | 2 ++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 17b89386..6db9e1aa 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -128,16 +128,16 @@ class DataManipulationBot(DatabaseClient): # perform the attack if not self.connected: self.connect() - if self.connected: - self.query(self.payload) - self.sys_log.info(f"{self.name} payload delivered: {self.payload}") - attack_successful = True - if attack_successful: - self.sys_log.info(f"{self.name}: Data manipulation successful") - self.attack_stage = DataManipulationAttackStage.COMPLETE - else: - self.sys_log.info(f"{self.name}: Data manipulation failed") - self.attack_stage = DataManipulationAttackStage.FAILED + if self.connected: + self.query(self.payload) + self.sys_log.info(f"{self.name} payload delivered: {self.payload}") + attack_successful = True + if attack_successful: + self.sys_log.info(f"{self.name}: Data manipulation successful") + self.attack_stage = DataManipulationAttackStage.COMPLETE + else: + self.sys_log.info(f"{self.name}: Data manipulation failed") + self.attack_stage = DataManipulationAttackStage.FAILED def run(self): """ diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index 8a78beae..936f7c5c 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -4,6 +4,7 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.services.red_services.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, @@ -64,6 +65,7 @@ def test_dm_bot_perform_data_manipulation_no_success(dm_bot): def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + dm_bot.operating_state = ApplicationOperatingState.RUNNING dm_bot._perform_data_manipulation(p_of_success=1.0) From 08c1b3cfb99ceae8aefbecf2331b868393ff1f59 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 15:56:04 +0000 Subject: [PATCH 21/35] Fix code style issues --- src/primaite/game/science.py | 2 +- src/primaite/simulator/system/applications/application.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py index f6215127..19a86237 100644 --- a/src/primaite/game/science.py +++ b/src/primaite/game/science.py @@ -1,7 +1,7 @@ from random import random -def simulate_trial(p_of_success: float): +def simulate_trial(p_of_success: float) -> bool: """ Simulates the outcome of a single trial in a Bernoulli process. diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 7f79ac2b..9a58c98a 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -66,7 +66,7 @@ class Application(IOSoftware): self.operating_state = ApplicationOperatingState.RUNNING def _application_loop(self): - """THe main application loop.""" + """The main application loop.""" pass def close(self) -> None: From afce6ca5159db196e50997207c3e4a637712e925 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 16:04:11 +0000 Subject: [PATCH 22/35] Update changelog for data manipulator bot & agent --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3af5c14c..9ddd0398 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,8 @@ SessionManager. - `DatabaseClient` and `DatabaseService` created to allow emulation of database actions - Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup - Red Agent Services: - - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database) + - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database). The attack runs in stages with a random, configurable probability of succeeding. + - `DataManipulationAgent` runs the Data Manipulator Bot according to a configured start step, frequency and variance. - DNS Services: `DNSClient` and `DNSServer` - FTP Services: `FTPClient` and `FTPServer` - HTTP Services: `WebBrowser` to simulate a web client and `WebServer` From cbdaa6c44418ba5d34c2221313054785defdf978 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 24 Nov 2023 16:32:04 +0000 Subject: [PATCH 23/35] Move data manipulation agent into individual file --- .../game/agent/data_manipulation_agent.py | 0 .../game/agent/data_manipulation_bot.py | 48 +++++++++++++++++++ src/primaite/game/agent/interface.py | 44 +---------------- src/primaite/game/session.py | 3 +- 4 files changed, 51 insertions(+), 44 deletions(-) create mode 100644 src/primaite/game/agent/data_manipulation_agent.py create mode 100644 src/primaite/game/agent/data_manipulation_bot.py diff --git a/src/primaite/game/agent/data_manipulation_agent.py b/src/primaite/game/agent/data_manipulation_agent.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py new file mode 100644 index 00000000..51221154 --- /dev/null +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -0,0 +1,48 @@ +import random +from typing import Dict, List, Tuple + +from gymnasium.core import ObsType + +from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot + + +class DataManipulationAgent(AbstractScriptedAgent): + """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + + data_manipulation_bots: List["DataManipulationBot"] = [] + next_execution_timestep: int = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) + + def _set_next_execution_timestep(self, timestep: int) -> None: + """Set the next execution timestep with a configured random variance. + + :param timestep: The timestep to add variance to. + """ + random_timestep_increment = random.randint( + -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + ) + self.next_execution_timestep = timestep + random_timestep_increment + + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + """Randomly sample an action from the action space. + + :param obs: _description_ + :type obs: ObsType + :param reward: _description_, defaults to None + :type reward: float, optional + :return: _description_ + :rtype: Tuple[str, Dict] + """ + current_timestep = self.action_manager.session.step_counter + + if current_timestep < self.next_execution_timestep: + return "DONOTHING", {"dummy": 0} + + self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) + + return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index b321b17c..6e783725 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,5 +1,4 @@ """Interface for agents.""" -import random from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, TYPE_CHECKING @@ -11,7 +10,7 @@ from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction if TYPE_CHECKING: - from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot + pass class AgentStartSettings(BaseModel): @@ -183,47 +182,6 @@ class ProxyAgent(AbstractAgent): self.most_recent_action = action -class DataManipulationAgent(AbstractScriptedAgent): - """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" - - data_manipulation_bots: List["DataManipulationBot"] = [] - next_execution_timestep: int = 0 - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) - - def _set_next_execution_timestep(self, timestep: int) -> None: - """Set the next execution timestep with a configured random variance. - - :param timestep: The timestep to add variance to. - """ - random_timestep_increment = random.randint( - -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance - ) - self.next_execution_timestep = timestep + random_timestep_increment - - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. - - :param obs: _description_ - :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ - :rtype: Tuple[str, Dict] - """ - current_timestep = self.action_manager.session.step_counter - - if current_timestep < self.next_execution_timestep: - return "DONOTHING", {"dummy": 0} - - self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) - - return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} - - class AbstractGATEAgent(AbstractAgent): """Base class for actors controlled via external messages, such as RL policies.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index f0dcdd61..095458b7 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -11,7 +11,8 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, ProxyAgent, RandomAgent +from primaite.game.agent.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.io import SessionIO, SessionIOSettings From ece9b14d6365c73b4278320c605e4f85113d613d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 26 Nov 2023 23:29:14 +0000 Subject: [PATCH 24/35] Resolve merge conflicts --- docs/source/primaite_session.rst | 215 +-- pyproject.toml | 1 + .../config/_package_data/example_config.yaml | 983 +++++++------- .../example_config_2_rl_agents.yaml | 1164 +++++++++++++++++ src/primaite/game/agent/actions.py | 26 +- .../game/agent/data_manipulation_bot.py | 2 +- src/primaite/game/agent/observations.py | 110 +- src/primaite/game/agent/rewards.py | 40 +- src/primaite/game/{session.py => game.py} | 238 +--- src/primaite/game/policy/__init__.py | 3 - src/primaite/main.py | 8 +- .../training_example_ray_multi_agent.ipynb | 127 ++ .../training_example_ray_single_agent.ipynb | 122 ++ .../notebooks/training_example_sb3.ipynb | 102 ++ src/primaite/notebooks/uc2_demo.ipynb | 306 +++++ src/primaite/session/__init__.py | 0 src/primaite/session/environment.py | 162 +++ src/primaite/{game => session}/io.py | 0 src/primaite/session/policy/__init__.py | 4 + .../{game => session}/policy/policy.py | 4 +- src/primaite/session/policy/rllib.py | 106 ++ src/primaite/{game => session}/policy/sb3.py | 4 +- src/primaite/session/session.py | 113 ++ .../assets/configs/bad_primaite_session.yaml | 1003 +++++++------- .../configs/eval_only_primaite_session.yaml | 1002 +++++++------- tests/assets/configs/multi_agent_session.yaml | 1156 ++++++++++++++++ .../assets/configs/test_primaite_session.yaml | 999 +++++++------- .../configs/train_only_primaite_session.yaml | 1003 +++++++------- tests/conftest.py | 3 +- .../test_rllib_multi_agent_environment.py | 45 + .../test_rllib_single_agent_environment.py | 40 + .../environments/test_sb3_environment.py | 27 + .../test_primaite_session.py | 24 +- 33 files changed, 6074 insertions(+), 3068 deletions(-) create mode 100644 src/primaite/config/_package_data/example_config_2_rl_agents.yaml rename src/primaite/game/{session.py => game.py} (71%) delete mode 100644 src/primaite/game/policy/__init__.py create mode 100644 src/primaite/notebooks/training_example_ray_multi_agent.ipynb create mode 100644 src/primaite/notebooks/training_example_ray_single_agent.ipynb create mode 100644 src/primaite/notebooks/training_example_sb3.ipynb create mode 100644 src/primaite/notebooks/uc2_demo.ipynb create mode 100644 src/primaite/session/__init__.py create mode 100644 src/primaite/session/environment.py rename src/primaite/{game => session}/io.py (100%) create mode 100644 src/primaite/session/policy/__init__.py rename src/primaite/{game => session}/policy/policy.py (93%) create mode 100644 src/primaite/session/policy/rllib.py rename src/primaite/{game => session}/policy/sb3.py (96%) create mode 100644 src/primaite/session/session.py create mode 100644 tests/assets/configs/multi_agent_session.yaml create mode 100644 tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py create mode 100644 tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py create mode 100644 tests/e2e_integration_tests/environments/test_sb3_environment.py diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 472a361f..f3ef0399 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -7,207 +7,28 @@ Run a PrimAITE Session ====================== +``PrimaiteSession`` allows the user to train or evaluate an RL agent on the primaite simulation with just a config file, +no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment, +policy, simulator, agents, and IO. + +If you want finer control over the RL policy, you can interface with the :py:module::`primaite.session.environment` +module directly without running a session. + + + Run --- -A PrimAITE session can be ran either with the ``primaite session`` command from the cli +A PrimAITE session can be started either with the ``primaite session`` command from the cli (See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook. -Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters. -.. note:: - 🚧 *UNDER CONSTRUCTION* 🚧 +There are two parameters that can be specified: + - ``--config``: The path to the config file to use. If not specified, the default config file is used. + - ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created. -.. - .. code-block:: bash - :caption: Unix CLI +Outputs +------- - cd ~/primaite/2.0.0 - source ./.venv/bin/activate - primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml - - .. code-block:: powershell - :caption: Powershell CLI - - cd ~\primaite\2.0.0 - .\.venv\Scripts\activate - primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml - - - .. code-block:: python - :caption: Python - - from primaite.main import run - - training_config = - lay_down_config = - run(training_config, lay_down_config) - - When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0/sessions``). - The sub-directory is formatted as such: ``~/primaite/2.0.0/sessions//_/`` - - For example, when running a session at 17:30:00 on 31st January 2023, the session will output to: - ``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``. - - ``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``. - - To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``legacy-ldc`` options. - - - - .. code-block:: bash - :caption: Unix CLI - - cd ~/primaite/2.0.0 - source ./.venv/bin/activate - primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc - - .. code-block:: powershell - :caption: Powershell CLI - - cd ~\primaite\2.0.0 - .\.venv\Scripts\activate - primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc - - - .. code-block:: python - :caption: Python - - from primaite.main import run - - training_config = - lay_down_config = - run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True) - - - - - Outputs - ------- - - PrimAITE produces four types of outputs: - - * Session Metadata - * Results - * Diagrams - * Saved agents (training checkpoints and a final trained agent) - - - **Session Metadata** - - PrimAITE creates a ``session_metadata.json`` file that contains the following metadata: - - * **uuid** - The UUID assigned to the session upon instantiation. - * **start_datetime** - The date & time the session started in iso format. - * **end_datetime** - The date & time the session ended in iso format. - * **learning** - * **total_episodes** - The total number of training episodes completed. - * **total_time_steps** - The total number of training time steps completed. - * **evaluation** - * **total_episodes** - The total number of evaluation episodes completed. - * **total_time_steps** - The total number of evaluation time steps completed. - * **env** - * **training_config** - * **All training config items** - * **lay_down_config** - * **All lay down config items** - - - **Results** - - PrimAITE automatically creates two sets of results from each learning and evaluation session: - - * Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value - * All transactions - a csv file listing the following values for every step of every episode: - - * Timestamp - * Episode number - * Step number - * Reward value - * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X - * Initial observation space (what the blue agent observed when it decided its action) - - **Diagrams** - - * For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration. - * For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory. - - **Saved agents** - - For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state. - - **Example Session Directory Structure** - - .. code-block:: text - - ~/ - └── primaite/ - └── 2.0.0/ - └── sessions/ - └── 2023-07-18/ - └── 2023-07-18_11-06-04/ - ├── evaluation/ - │ ├── all_transactions_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv - │ └── average_reward_per_episode_2023-07-18_11-06-04.png - ├── learning/ - │ ├── all_transactions_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.png - │ ├── checkpoints/ - │ │ └── sb3ppo_10.zip - │ ├── SB3_PPO.zip - │ └── tensorboard_logs/ - │ ├── PPO_1/ - │ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0 - │ ├── PPO_2/ - │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1 - │ ├── PPO_3/ - │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2 - │ ├── PPO_4/ - │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3 - │ ├── PPO_5/ - │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4 - │ ├── PPO_6/ - │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5 - │ ├── PPO_7/ - │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6 - │ ├── PPO_8/ - │ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7 - │ ├── PPO_9/ - │ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8 - │ └── PPO_10/ - │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9 - ├── network_2023-07-18_11-06-04.png - └── session_metadata.json - - Loading a session - ----------------- - - A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli - (See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path. - - .. tabs:: - - .. code-tab:: bash - :caption: Unix CLI - - cd ~/primaite/2.0.0 - source ./.venv/bin/activate - primaite session --load "path/to/session" - - .. code-tab:: bash - :caption: Powershell CLI - - cd ~\primaite\2.0.0 - .\.venv\Scripts\activate - primaite session --load "path\to\session" - - - .. code-tab:: python - :caption: Python - - from primaite.main import run - - run(session_path=) - - When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory +Running a session creates a session output directory in your user data folder. The filepath looks like this: +``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node, +the saved agent checkpoints, and final model. diff --git a/pyproject.toml b/pyproject.toml index 92f78ec0..3e5b959a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "tensorflow==2.12.0", "typer[all]==0.9.0", "pydantic==2.1.1", + "ray[rllib] == 2.8.0, < 3" ] [tool.setuptools.dynamic] diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 6455272c..d9896b01 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -1,8 +1,8 @@ training_config: - rl_framework: SB3 + rl_framework: RLLIB_single_agent rl_algorithm: PPO seed: 333 - n_learn_episodes: 25 + n_learn_episodes: 1 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false @@ -15,7 +15,8 @@ io_settings: checkpoint_interval: 5 -game_config: +game: + max_episode_length: 256 ports: - ARP - DNS @@ -26,522 +27,504 @@ game_config: - TCP - UDP - agents: - - ref: client_1_green_user - team: GREEN - type: GreenWebBrowsingAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 - node_ref: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null - action_space: - action_list: - - type: DONOTHING - # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # None: """Init method for ActionManager. - :param session: Reference to the session to which the agent belongs. - :type session: PrimaiteSession + :param game: Reference to the game to which the agent belongs. + :type game: PrimaiteGame :param actions: List of action types which should be made available to the agent. :type actions: List[str] :param node_uuids: List of node UUIDs that this agent can act on. @@ -633,8 +633,8 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.session: "PrimaiteSession" = session - self.sim: Simulation = self.session.simulation + self.game: "PrimaiteGame" = game + self.sim: Simulation = self.game.simulation self.node_uuids: List[str] = node_uuids self.application_uuids: List[List[str]] = application_uuids self.protocols: List[str] = protocols @@ -874,7 +874,7 @@ class ActionManager: return nics[nic_idx] @classmethod - def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager": + def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": """ Construct an ActionManager from a config definition. @@ -893,20 +893,20 @@ class ActionManager: These options are used to calculate the shape of the action space, and to provide additional information to the ActionManager which is required to convert the agent's action choice into a CAOS request. - :param session: The Primaite Session to which the agent belongs. - :type session: PrimaiteSession + :param game: The Primaite Game to which the agent belongs. + :type game: PrimaiteGame :param cfg: The action space config. :type cfg: Dict :return: The constructed ActionManager. :rtype: ActionManager """ obj = cls( - session=session, + game=game, actions=cfg["action_list"], # node_uuids=cfg["options"]["node_uuids"], **cfg["options"], - protocols=session.options.protocols, - ports=session.options.ports, + protocols=game.options.protocols, + ports=game.options.ports, ip_address_list=None, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 51221154..8237ce06 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -38,7 +38,7 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - current_timestep = self.action_manager.session.step_counter + current_timestep = self.action_manager.game.step_counter if current_timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a74771c0..14fb2fa7 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractObservation(ABC): @@ -37,10 +37,10 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config: Dict, session: "PrimaiteSession"): + def from_config(cls, config: Dict, game: "PrimaiteGame"): """Create this observation space component form a serialised format. - The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, + The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation, a subclass of this class may need to translate from a 'reference' to a UUID. """ pass @@ -91,13 +91,13 @@ class FileObservation(AbstractObservation): return spaces.Dict({"health_status": spaces.Discrete(6)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": """Create file observation from a config. :param config: Dictionary containing the configuration for this file observation. :type config: Dict - :param session: _description_ - :type session: PrimaiteSession + :param game: _description_ + :type game: PrimaiteGame :param parent_where: _description_, defaults to None :type parent_where: _type_, optional :return: _description_ @@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None ) -> "ServiceObservation": """Create service observation from a config. :param config: Dictionary containing the configuration for this service observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. :type parent_where: Optional[List[str]], optional :return: Constructed service observation :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid]) + return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid]) class LinkObservation(AbstractObservation): @@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation): return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": """Create link observation from a config. :param config: Dictionary containing the configuration for this link observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Constructed link observation :rtype: LinkObservation """ - return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]]) + return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) class FolderObservation(AbstractObservation): @@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 ) -> "FolderObservation": """Create folder observation from a config. Also creates child file observations. :param config: Dictionary containing the configuration for this folder observation. Includes the name of the folder and the files inside of it. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this folder's parent node. A typical location for a node ``where`` can be: ['network','nodes',,'file_system'] @@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation): where = parent_where + ["folders", config["folder_name"]] file_configs = config["files"] - files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) @@ -376,15 +376,13 @@ class NicObservation(AbstractObservation): return spaces.Dict({"nic_status": spaces.Discrete(3)}) @classmethod - def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] - ) -> "NicObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": """Create NIC observation from a config. :param config: Dictionary containing the configuration for this NIC observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent node. A typical location for a node ``where`` can be: ['network','nodes',] :type parent_where: Optional[List[str]] @@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation): def from_config( cls, config: Dict, - session: "PrimaiteSession", + game: "PrimaiteGame", parent_where: Optional[List[str]] = None, num_services_per_node: int = 2, num_folders_per_node: int = 2, @@ -526,8 +524,8 @@ class NodeObservation(AbstractObservation): :param config: Dictionary containing the configuration for this node observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this node's parent network. A typical location for it would be: ['network',] :type parent_where: Optional[List[str]] @@ -543,24 +541,24 @@ class NodeObservation(AbstractObservation): :return: Constructed node observation :rtype: NodeObservation """ - node_uuid = session.ref_map_nodes[config["node_ref"]] + node_uuid = game.ref_map_nodes[config["node_ref"]] if parent_where is None: where = ["network", "nodes", node_uuid] else: where = parent_where + ["nodes", node_uuid] svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs] + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] folder_configs = config.get("folders", {}) folders = [ FolderObservation.from_config( - config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder + config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder ) for c in folder_configs ] - nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys() + nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys() nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else [] - nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs] + nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] logon_status = config.get("logon_status", False) return cls( where=where, @@ -694,13 +692,13 @@ class AclObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": """Generate ACL observation from a config. :param config: Dictionary containing the configuration for this ACL observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Observation object :rtype: AclObservation """ @@ -709,15 +707,15 @@ class AclObservation(AbstractObservation): for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): node_ref = ip_map_config["node_ref"] nic_num = ip_map_config["nic_num"] - node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]] + node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] nic_obj = node_obj.ethernet_port[nic_num] node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - router_uuid = session.ref_map_nodes[config["router_node_ref"]] + router_uuid = game.ref_map_nodes[config["router_node_ref"]] return cls( node_ip_to_id=node_ip_to_idx, - ports=session.options.ports, - protocols=session.options.protocols, + ports=game.options.ports, + protocols=game.options.protocols, where=["network", "nodes", router_uuid, "acl", "acl"], num_rules=max_acl_rules, ) @@ -740,7 +738,7 @@ class NullObservation(AbstractObservation): return spaces.Discrete(1) @classmethod - def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation": + def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": """ Create null observation from a config. @@ -836,14 +834,14 @@ class UC2BlueObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": """Create UC2 blue observation from a config. :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, links, ACL and ICS observations. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Constructed UC2 blue observation :rtype: UC2BlueObservation """ @@ -855,7 +853,7 @@ class UC2BlueObservation(AbstractObservation): nodes = [ NodeObservation.from_config( config=n, - session=session, + game=game, num_services_per_node=num_services_per_node, num_folders_per_node=num_folders_per_node, num_files_per_folder=num_files_per_folder, @@ -865,13 +863,13 @@ class UC2BlueObservation(AbstractObservation): ] link_configs = config["links"] - links = [LinkObservation.from_config(config=link, session=session) for link in link_configs] + links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, session=session) + acl = AclObservation.from_config(config=acl_config, game=game) ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, session=session) + ics = ICSObservation.from_config(config=ics_config, game=game) new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) return new @@ -907,17 +905,17 @@ class UC2RedObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": """ Create UC2 red observation from a config. :param config: Dictionary containing the configuration for this UC2 red observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame """ node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs] + nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] return cls(nodes=nodes, where=["network"]) @@ -966,7 +964,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. @@ -974,14 +972,14 @@ class ObservationManager: UC2BlueObservation, UC2RedObservation, UC2GreenObservation) The other key is 'options' which are passed to the constructor of the selected observation class. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame """ if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) else: raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index da1331b0..8a1c2da4 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractReward: @@ -47,13 +47,13 @@ class AbstractReward: @classmethod @abstractmethod - def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward": + def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: AbstractReward """ @@ -68,13 +68,13 @@ class DummyReward(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward": + def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor. Should be empty. :type config: dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame """ return cls() @@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward): return 0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: DatabaseFileIntegrity """ @@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward): f"{cls.__name__} could not be initialised from config because file_name parameter was not specified" ) return DummyReward() # TODO: better error handling - node_uuid = session.ref_map_nodes[node_ref] + node_uuid = game.ref_map_nodes[node_ref] if not node_uuid: _LOGGER.error( ( @@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: WebServer404Penalty """ @@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward): ) _LOGGER.warn(msg) return DummyReward() # TODO: should we error out with incorrect inputs? Probably! - node_uuid = session.ref_map_nodes[node_ref] - service_uuid = session.ref_map_services[service_ref].uuid + node_uuid = game.ref_map_nodes[node_ref] + service_uuid = game.ref_map_services[service_ref].uuid if not (node_uuid and service_uuid): msg = ( f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" @@ -265,13 +265,13 @@ class RewardFunction: return self.current_reward @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction": """Create a reward function from a config dictionary. :param config: dict of options for the reward manager's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward manager. :rtype: RewardFunction """ @@ -281,6 +281,6 @@ class RewardFunction: rew_type = rew_component_cfg["type"] weight = rew_component_cfg.get("weight", 1.0) rew_class = cls.__rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session) + rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game) new.regsiter_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/session.py b/src/primaite/game/game.py similarity index 71% rename from src/primaite/game/session.py rename to src/primaite/game/game.py index 095458b7..ae60bbc1 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/game.py @@ -1,12 +1,8 @@ -"""PrimAITE session - the main entry point to training agents on PrimAITE.""" +"""PrimAITE game - Encapsulates the simulation and agents.""" from copy import deepcopy -from enum import Enum from ipaddress import IPv4Address -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple +from typing import Dict, List -import gymnasium -from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict from primaite import getLogger @@ -15,8 +11,6 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.game.io import SessionIO, SessionIOSettings -from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -40,65 +34,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteGymEnv(gymnasium.Env): - """ - Thin wrapper env to provide agents with a gymnasium API. - - This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some - assumptions about the agent list always having a list of length 1. - """ - - def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): - """Initialise the environment.""" - super().__init__() - self.session: "PrimaiteSession" = session - self.agent: ProxyAgent = agents[0] - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """Perform a step in the environment.""" - # make ProxyAgent store the action chosen my the RL policy - self.agent.store_action(action) - # apply_agent_actions accesses the action we just stored - self.session.apply_agent_actions() - self.session.advance_timestep() - state = self.session.get_sim_state() - self.session.update_agents(state) - - next_obs = self._get_obs() - reward = self.agent.reward_function.current_reward - terminated = False - truncated = self.session.calculate_truncated() - info = {} - - return next_obs, reward, terminated, truncated, info - - def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: - """Reset the environment.""" - self.session.reset() - state = self.session.get_sim_state() - self.session.update_agents(state) - next_obs = self._get_obs() - info = {} - return next_obs, info - - @property - def action_space(self) -> gymnasium.Space: - """Return the action space of the environment.""" - return self.agent.action_manager.space - - @property - def observation_space(self) -> gymnasium.Space: - """Return the observation space of the environment.""" - return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) - - def _get_obs(self) -> ObsType: - """Return the current observation.""" - unflat_space = self.agent.observation_manager.space - unflat_obs = self.agent.observation_manager.current_observation - return gymnasium.spaces.flatten(unflat_space, unflat_obs) - - -class PrimaiteSessionOptions(BaseModel): +class PrimaiteGameOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. @@ -107,40 +43,20 @@ class PrimaiteSessionOptions(BaseModel): model_config = ConfigDict(extra="forbid") + max_episode_length: int = 256 ports: List[str] protocols: List[str] -class TrainingOptions(BaseModel): - """Options for training the RL agent.""" +class PrimaiteGame: + """ + Primaite game encapsulates the simulation and agents which interact with it. - model_config = ConfigDict(extra="forbid") - - rl_framework: Literal["SB3", "RLLIB"] - rl_algorithm: Literal["PPO", "A2C"] - n_learn_episodes: int - n_eval_episodes: Optional[int] = None - max_steps_per_episode: int - # checkpoint_freq: Optional[int] = None - deterministic_eval: bool - seed: Optional[int] - n_agents: int - agent_references: List[str] - - -class SessionMode(Enum): - """Helper to keep track of the current session mode.""" - - TRAIN = "train" - EVAL = "eval" - MANUAL = "manual" - - -class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" + Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment. + """ def __init__(self): - """Initialise a PrimaiteSession object.""" + """Initialise a PrimaiteGame object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" @@ -159,15 +75,9 @@ class PrimaiteSession: self.episode_counter: int = 0 """Current episode number.""" - self.options: PrimaiteSessionOptions + self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" - self.training_options: TrainingOptions - """Options specific to agent training.""" - - self.policy: PolicyABC - """The reinforcement learning policy.""" - self.ref_map_nodes: Dict[str, Node] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" @@ -180,40 +90,6 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" - self.env: PrimaiteGymEnv - """The environment that the agent can consume. Could be PrimaiteEnv.""" - - self.mode: SessionMode = SessionMode.MANUAL - """Current session mode.""" - - self.io_manager = SessionIO() - """IO manager for the session.""" - - def start_session(self) -> None: - """Commence the training session.""" - self.mode = SessionMode.TRAIN - n_learn_episodes = self.training_options.n_learn_episodes - n_eval_episodes = self.training_options.n_eval_episodes - max_steps_per_episode = self.training_options.max_steps_per_episode - - deterministic_eval = self.training_options.deterministic_eval - self.policy.learn( - n_episodes=n_learn_episodes, - timesteps_per_episode=max_steps_per_episode, - ) - self.save_models() - - self.mode = SessionMode.EVAL - if n_eval_episodes > 0: - self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) - - self.mode = SessionMode.MANUAL - - def save_models(self) -> None: - """Save the RL models.""" - save_path = self.io_manager.generate_model_save_path("temp_model_name") - self.policy.save(save_path) - def step(self): """ Perform one step of the simulation/agent loop. @@ -232,7 +108,7 @@ class PrimaiteSession: single-agent gym, make sure to update the ProxyAgent's action with the action before calling ``self.apply_agent_actions()``. """ - _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") + _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") # Get the current state of the simulation sim_state = self.get_sim_state() @@ -274,29 +150,29 @@ class PrimaiteSession: def calculate_truncated(self) -> bool: """Calculate whether the episode is truncated.""" current_step = self.step_counter - max_steps = self.training_options.max_steps_per_episode + max_steps = self.options.max_episode_length if current_step >= max_steps: return True return False def reset(self) -> None: - """Reset the session, this will reset the simulation.""" + """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 - _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") + _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") self.simulation = deepcopy(self._simulation_initial_state) def close(self) -> None: - """Close the session, this will stop the env and close the simulation.""" + """Close the game, this will close the simulation.""" return NotImplemented @classmethod - def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": - """Create a PrimaiteSession object from a config dictionary. + def from_config(cls, cfg: Dict) -> "PrimaiteGame": + """Create a PrimaiteGame object from a config dictionary. The config dictionary should have the following top-level keys: 1. training_config: options for training the RL agent. - 2. game_config: options for the game itself. Used by PrimaiteSession. + 2. game_config: options for the game itself. Used by PrimaiteGame. 3. simulation: defines the network topology and the initial state of the simulation. The specification for each of the three major areas is described in a separate documentation page. @@ -304,26 +180,19 @@ class PrimaiteSession: :param cfg: The config dictionary. :type cfg: dict - :return: A PrimaiteSession object. - :rtype: PrimaiteSession + :return: A PrimaiteGame object. + :rtype: PrimaiteGame """ - sess = cls() - sess.options = PrimaiteSessionOptions( - ports=cfg["game_config"]["ports"], - protocols=cfg["game_config"]["protocols"], - ) - sess.training_options = TrainingOptions(**cfg["training_config"]) + game = cls() + game.options = PrimaiteGameOptions(**cfg["game"]) - # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - - sim = sess.simulation + # 1. create simulation + sim = game.simulation net = sim.network - sess.ref_map_nodes: Dict[str, Node] = {} - sess.ref_map_services: Dict[str, Service] = {} - sess.ref_map_links: Dict[str, Link] = {} + game.ref_map_nodes: Dict[str, Node] = {} + game.ref_map_services: Dict[str, Service] = {} + game.ref_map_links: Dict[str, Link] = {} nodes_cfg = cfg["simulation"]["network"]["nodes"] links_cfg = cfg["simulation"]["network"]["links"] @@ -400,7 +269,7 @@ class PrimaiteSession: print(f"installing {service_type} on node {new_node.hostname}") new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] - sess.ref_map_services[service_ref] = new_service + game.ref_map_services[service_ref] = new_service else: print(f"service type not found {service_type}") # service-dependent options @@ -434,7 +303,7 @@ class PrimaiteSession: if application_type in application_types_mapping: new_node.software_manager.install(application_types_mapping[application_type]) new_application = new_node.software_manager.software[application_type] - sess.ref_map_applications[application_ref] = new_application + game.ref_map_applications[application_ref] = new_application else: print(f"application type not found {application_type}") @@ -442,7 +311,7 @@ class PrimaiteSession: if "options" in application_cfg: opt = application_cfg["options"] new_application.configure( - server_ip_address=opt.get("server_ip"), + server_ip_address=IPv4Address(opt.get("server_ip")), payload=opt.get("payload"), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), @@ -453,7 +322,7 @@ class PrimaiteSession: net.add_node(new_node) new_node.power_on() - sess.ref_map_nodes[ + game.ref_map_nodes[ node_ref ] = ( new_node.uuid @@ -461,8 +330,8 @@ class PrimaiteSession: # 2. create links between nodes for link_cfg in links_cfg: - node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]] - node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]] + node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]] + node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]] if isinstance(node_a, Switch): endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]] else: @@ -472,13 +341,10 @@ class PrimaiteSession: else: endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]] new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b) - sess.ref_map_links[link_cfg["ref"]] = new_link.uuid - # endpoint_a.enable() - # endpoint_b.enable() + game.ref_map_links[link_cfg["ref"]] = new_link.uuid # 3. create agents - game_cfg = cfg["game_config"] - agents_cfg = game_cfg["agents"] + agents_cfg = cfg["agents"] for agent_cfg in agents_cfg: agent_ref = agent_cfg["ref"] # noqa: F841 @@ -488,7 +354,7 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, game) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] @@ -497,7 +363,7 @@ class PrimaiteSession: # if a list of nodes is defined, convert them from node references to node UUIDs for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}): if "node_ref" in action_node_option: - node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]] + node_uuid = game.ref_map_nodes[action_node_option["node_ref"]] action_space_cfg["options"]["node_uuids"].append(node_uuid) if "applications" in action_node_option: @@ -505,7 +371,7 @@ class PrimaiteSession: for application_option in action_node_option["applications"]: # TODO: fix inconsistency with node uuids and application uuids. The node object get added to # node_uuid, whereas here the application gets added by uuid. - application_uuid = sess.ref_map_applications[application_option["application_ref"]].uuid + application_uuid = game.ref_map_applications[application_option["application_ref"]].uuid node_application_uuids.append(application_uuid) action_space_cfg["options"]["application_uuids"].append(node_application_uuids) @@ -522,12 +388,12 @@ class PrimaiteSession: if "options" in action_config: if "target_router_ref" in action_config["options"]: _target = action_config["options"]["target_router_ref"] - action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target] + action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target] - action_space = ActionManager.from_config(sess, action_space_cfg) + action_space = ActionManager.from_config(game, action_space_cfg) # CREATE REWARD FUNCTION - rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) + rew_function = RewardFunction.from_config(reward_function_cfg, game=game) agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) @@ -541,7 +407,7 @@ class PrimaiteSession: reward_function=rew_function, agent_settings=agent_settings, ) - sess.agents.append(new_agent) + game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -549,8 +415,8 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) - sess.rl_agents.append(new_agent) + game.agents.append(new_agent) + game.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], @@ -559,18 +425,10 @@ class PrimaiteSession: reward_function=rew_function, agent_settings=agent_settings, ) - sess.agents.append(new_agent) + game.agents.append(new_agent) else: print("agent type not found") - # CREATE ENVIRONMENT - sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) + game._simulation_initial_state = deepcopy(game.simulation) # noqa - # CREATE POLICY - sess.policy = PolicyABC.from_config(sess.training_options, session=sess) - if agent_load_path: - sess.policy.load(Path(agent_load_path)) - - sess._simulation_initial_state = deepcopy(sess.simulation) # noqa - - return sess + return game diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py deleted file mode 100644 index 29196112..00000000 --- a/src/primaite/game/policy/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from primaite.game.policy.sb3 import SB3Policy - -__all__ = ["SB3Policy"] diff --git a/src/primaite/main.py b/src/primaite/main.py index 1699fe51..b63227a7 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -5,8 +5,8 @@ from pathlib import Path from typing import Optional, Union from primaite import getLogger -from primaite.config.load import load -from primaite.game.session import PrimaiteSession +from primaite.config.load import example_config_path, load +from primaite.session.session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession @@ -42,6 +42,6 @@ if __name__ == "__main__": args = parser.parse_args() if not args.config: - _LOGGER.error("Please provide a config file using the --config " "argument") + args.config = example_config_path() - run(session_path=args.config) + run(args.config) diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb new file mode 100644 index 00000000..d31d53cc --- /dev/null +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.session.environment import PrimaiteRayEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# gym = PrimaiteRayEnv({\"game\":game})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "from ray import air, tune\n", + "from ray.rllib.algorithms.ppo import PPOConfig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()\n", + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.session.environment import PrimaiteRayMARLEnv\n", + "\n", + "\n", + "env_config = {\"game\":game}\n", + "config = (\n", + " PPOConfig()\n", + " .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n", + " .rollouts(num_rollout_workers=0)\n", + " .multi_agent(\n", + " policies={agent.agent_name for agent in game.rl_agents},\n", + " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " )\n", + " .training(train_batch_size=128)\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tune.Tuner(\n", + " \"PPO\",\n", + " run_config=air.RunConfig(\n", + " stop={\"training_iteration\": 128},\n", + " checkpoint_config=air.CheckpointConfig(\n", + " checkpoint_frequency=10,\n", + " ),\n", + " ),\n", + " param_space=config\n", + ").fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb new file mode 100644 index 00000000..8ee16d41 --- /dev/null +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -0,0 +1,122 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.session.environment import PrimaiteRayEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gym = PrimaiteRayEnv({\"game\":game})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "from ray.rllib.algorithms import ppo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()\n", + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env_config = {\"game\":game}\n", + "config = {\n", + " \"env\" : PrimaiteRayEnv,\n", + " \"env_config\" : env_config,\n", + " \"disable_env_checking\": True,\n", + " \"num_rollout_workers\": 0,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo = ppo.PPO(config=config)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(5):\n", + " result = algo.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo.save(\"temp/deleteme\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb new file mode 100644 index 00000000..e5085c5e --- /dev/null +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "import yaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.config.load import example_config_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gym = PrimaiteGymEnv(game=game)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3 import PPO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = PPO('MlpPolicy', gym)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.learn(total_timesteps=1000)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.save(\"deleteme\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb new file mode 100644 index 00000000..3950ef10 --- /dev/null +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" + ] + } + ], + "source": [ + "from primaite.session.session import PrimaiteSession\n", + "from primaite.game.game import PrimaiteGame\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.simulator.system.services.database.database_service import DatabaseService\n", + "\n", + "import yaml" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n", + "2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "installing DNSServer on node domain_controller\n", + "installing DatabaseClient on node web_server\n", + "installing WebServer on node web_server\n", + "installing DatabaseService on node database_server\n", + "installing FTPClient on node database_server\n", + "installing FTPServer on node backup_server\n", + "installing DNSClient on node client_1\n", + "installing DNSClient on node client_2\n" + ] + } + ], + "source": [ + "\n", + "with open(example_config_path(),'r') as cfgfile:\n", + " cfg = yaml.safe_load(cfgfile)\n", + "game = PrimaiteGame.from_config(cfg)\n", + "net = game.simulation.network\n", + "database_server = net.get_node_by_hostname('database_server')\n", + "web_server = net.get_node_by_hostname('web_server')\n", + "client_1 = net.get_node_by_hostname('client_1')\n", + "\n", + "db_service = database_server.software_manager.software[\"DatabaseService\"]\n", + "db_client = web_server.software_manager.software[\"DatabaseClient\"]\n", + "# db_client.run()\n", + "db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n", + "db_manipulation_bot.port_scan_p_of_success=1.0\n", + "db_manipulation_bot.data_manipulation_p_of_success=1.0\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "db_client.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_service.backup_database()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_client.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "db_manipulation_bot.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_client.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_service.restore_backup()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_client.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "db_manipulation_bot.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client_1.ping(database_server.ethernet_port[1].ip_address)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import validate_call, BaseModel" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class A(BaseModel):\n", + " x:int\n", + "\n", + " @validate_call\n", + " def increase_x(self, by:int) -> None:\n", + " self.x += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "my_a = A(x=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "ename": "ValidationError", + "evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n", + "File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n", + "\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float" + ] + } + ], + "source": [ + "my_a.increase_x(3.2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/session/__init__.py b/src/primaite/session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py new file mode 100644 index 00000000..db24db60 --- /dev/null +++ b/src/primaite/session/environment.py @@ -0,0 +1,162 @@ +from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple + +import gymnasium +from gymnasium.core import ActType, ObsType +from ray.rllib.env.multi_agent_env import MultiAgentEnv + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame + + +class PrimaiteGymEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. + + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. + """ + + def __init__(self, game: PrimaiteGame): + """Initialise the environment.""" + super().__init__() + self.game: "PrimaiteGame" = game + self.agent: ProxyAgent = self.game.rl_agents[0] + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen my the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.game.apply_agent_actions() + self.game.advance_timestep() + state = self.game.get_sim_state() + self.game.update_agents(state) + + next_obs = self._get_obs() + reward = self.agent.reward_function.current_reward + terminated = False + truncated = self.game.calculate_truncated() + info = {} + + return next_obs, reward, terminated, truncated, info + + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: + """Reset the environment.""" + self.game.reset() + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.space + + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) + + +class PrimaiteRayEnv(gymnasium.Env): + """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" + + def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict[str, PrimaiteGame] + """ + self.env = PrimaiteGymEnv(game=env_config["game"]) + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + return self.env.reset(seed=seed) + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + """Perform a step in the environment.""" + return self.env.step(action) + + +class PrimaiteRayMARLEnv(MultiAgentEnv): + """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" + + def __init__(self, env_config: Optional[Dict] = None) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict[str, PrimaiteGame] + """ + self.game: PrimaiteGame = env_config["game"] + """Reference to the primaite game""" + self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} + """List of all possible agents in the environment. This list should not change!""" + self._agent_ids = list(self.agents.keys()) + + self.terminateds = set() + self.truncateds = set() + self.observation_space = gymnasium.spaces.Dict( + {name: agent.observation_manager.space for name, agent in self.agents.items()} + ) + self.action_space = gymnasium.spaces.Dict( + {name: agent.action_manager.space for name, agent in self.agents.items()} + ) + super().__init__() + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + self.game.reset() + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + def step( + self, actions: Dict[str, ActType] + ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: + """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. + + :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. + :type actions: Dict[str, ActType] + :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent + identifier. + :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] + """ + # 1. Perform actions + for agent_name, action in actions.items(): + self.agents[agent_name].store_action(action) + self.game.apply_agent_actions() + + # 2. Advance timestep + self.game.advance_timestep() + + # 3. Get next observations + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + + # 4. Get rewards + rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + terminateds = {name: False for name, _ in self.agents.items()} + truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} + infos = {} + terminateds["__all__"] = len(self.terminateds) == len(self.agents) + truncateds["__all__"] = self.game.calculate_truncated() + return next_obs, rewards, terminateds, truncateds, infos + + def _get_obs(self) -> Dict[str, ObsType]: + """Return the current observation.""" + return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()} diff --git a/src/primaite/game/io.py b/src/primaite/session/io.py similarity index 100% rename from src/primaite/game/io.py rename to src/primaite/session/io.py diff --git a/src/primaite/session/policy/__init__.py b/src/primaite/session/policy/__init__.py new file mode 100644 index 00000000..811c7a54 --- /dev/null +++ b/src/primaite/session/policy/__init__.py @@ -0,0 +1,4 @@ +from primaite.session.policy.rllib import RaySingleAgentPolicy +from primaite.session.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/session/policy/policy.py similarity index 93% rename from src/primaite/game/policy/policy.py rename to src/primaite/session/policy/policy.py index 249c3b52..984466d1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/session/policy/policy.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions class PolicyABC(ABC): @@ -80,5 +80,3 @@ class PolicyABC(ABC): PolicyType = cls._registry[config.rl_framework] return PolicyType.from_config(config=config, session=session) - - # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/session/policy/rllib.py b/src/primaite/session/policy/rllib.py new file mode 100644 index 00000000..be181797 --- /dev/null +++ b/src/primaite/session/policy/rllib.py @@ -0,0 +1,106 @@ +from pathlib import Path +from typing import Literal, Optional, TYPE_CHECKING + +from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.session.policy.policy import PolicyABC + +if TYPE_CHECKING: + from primaite.session.session import PrimaiteSession, TrainingOptions + +import ray +from ray import air, tune +from ray.rllib.algorithms import ppo +from ray.rllib.algorithms.ppo import PPOConfig + + +class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): + """Single agent RL policy using Ray RLLib.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + super().__init__(session=session) + + config = { + "env": PrimaiteRayEnv, + "env_config": {"game": session.game}, + "disable_env_checking": True, + "num_rollout_workers": 0, + } + + ray.shutdown() + ray.init() + + self._algo = ppo.PPO(config=config) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + for ep in range(n_episodes): + self._algo.train() + + def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate the agent.""" + for ep in range(n_episodes): + obs, info = self.session.env.reset() + for step in range(self.session.game.options.max_episode_length): + action = self._algo.compute_single_action(observation=obs, explore=False) + obs, rew, term, trunc, info = self.session.env.step(action) + + def save(self, save_path: Path) -> None: + """Save the policy to a file.""" + self._algo.save(save_path) + + def load(self, model_path: Path) -> None: + """Load policy parameters from a file.""" + raise NotImplementedError + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": + """Create a policy from a config.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) + + +class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"): + """Mutli agent RL policy using Ray RLLib.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None): + """Initialise multi agent policy wrapper.""" + super().__init__(session=session) + + self.config = ( + PPOConfig() + .environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game}) + .rollouts(num_rollout_workers=0) + .multi_agent( + policies={agent.agent_name for agent in session.game.rl_agents}, + policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id, + ) + .training(train_batch_size=128) + ) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + checkpoint_freq = self.session.io_manager.settings.checkpoint_interval + tune.Tuner( + "PPO", + run_config=air.RunConfig( + stop={"training_iteration": n_episodes * timesteps_per_episode}, + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq), + ), + param_space=self.config, + ).fit() + + def load(self, model_path: Path) -> None: + """Load policy parameters from a file.""" + return NotImplemented + + def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate trained policy.""" + return NotImplemented + + def save(self, save_path: Path) -> None: + """Save policy parameters to a file.""" + return NotImplemented + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy": + """Create policy from config.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/session/policy/sb3.py similarity index 96% rename from src/primaite/game/policy/sb3.py rename to src/primaite/session/policy/sb3.py index a4870054..051e2770 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -8,10 +8,10 @@ from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy as PPO_MLP -from primaite.game.policy.policy import PolicyABC +from primaite.session.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions class SB3Policy(PolicyABC, identifier="SB3"): diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py new file mode 100644 index 00000000..80b63ba7 --- /dev/null +++ b/src/primaite/session/session.py @@ -0,0 +1,113 @@ +from enum import Enum +from pathlib import Path +from typing import Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict + +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.session.io import SessionIO, SessionIOSettings + +# from primaite.game.game import PrimaiteGame +from primaite.session.policy.policy import PolicyABC + + +class TrainingOptions(BaseModel): + """Options for training the RL agent.""" + + model_config = ConfigDict(extra="forbid") + + rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"] + rl_algorithm: Literal["PPO", "A2C"] + n_learn_episodes: int + n_eval_episodes: Optional[int] = None + max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None + deterministic_eval: bool + seed: Optional[int] + n_agents: int + agent_references: List[str] + + +class SessionMode(Enum): + """Helper to keep track of the current session mode.""" + + TRAIN = "train" + EVAL = "eval" + MANUAL = "manual" + + +class PrimaiteSession: + """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" + + def __init__(self, game: PrimaiteGame): + """Initialise PrimaiteSession object.""" + self.training_options: TrainingOptions + """Options specific to agent training.""" + + self.mode: SessionMode = SessionMode.MANUAL + """Current session mode.""" + + self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv] + """The environment that the RL algorithm can consume.""" + + self.policy: PolicyABC + """The reinforcement learning policy.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" + + self.game: PrimaiteGame = game + """Primaite Game object for managing main simulation loop and agents.""" + + def start_session(self) -> None: + """Commence the training/eval session.""" + self.mode = SessionMode.TRAIN + n_learn_episodes = self.training_options.n_learn_episodes + n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + + deterministic_eval = self.training_options.deterministic_eval + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) + self.save_models() + + self.mode = SessionMode.EVAL + if n_eval_episodes > 0: + self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) + + self.mode = SessionMode.MANUAL + + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) + + @classmethod + def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": + """Create a PrimaiteSession object from a config dictionary.""" + game = PrimaiteGame.from_config(cfg) + + sess = cls(game=game) + + sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) + + # CREATE ENVIRONMENT + if sess.training_options.rl_framework == "RLLIB_single_agent": + sess.env = PrimaiteRayEnv(env_config={"game": game}) + elif sess.training_options.rl_framework == "RLLIB_multi_agent": + sess.env = PrimaiteRayMARLEnv(env_config={"game": game}) + elif sess.training_options.rl_framework == "SB3": + sess.env = PrimaiteGymEnv(game=game) + + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) + + return sess diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 4d8e4669..9070f246 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -2,20 +2,12 @@ training_config: rl_framework: SB3 rl_algorithm: PPO se3ed: 333 # Purposeful typo to check that error is raised with bad configuration. - n_learn_episodes: 25 + n_learn_steps: 2560 n_eval_episodes: 5 - max_steps_per_episode: 128 - deterministic_eval: false - n_agents: 1 - agent_references: - - defender - -io_settings: - save_checkpoints: true - checkpoint_interval: 5 -game_config: + +game: ports: - ARP - DNS @@ -26,522 +18,499 @@ game_config: - TCP - UDP - agents: - - ref: client_1_green_user - team: GREEN - type: GreenWebBrowsingAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + reward_function: + reward_components: + - type: DUMMY - reward_function: - reward_components: - - type: DUMMY + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + observation_space: + type: UC2RedObservation + options: + nodes: {} - observation_space: - type: UC2RedObservation - options: - nodes: + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + - type: NODE_FILE_DELETE + - type: NODE_FILE_CORRUPT + - type: NODE_OS_SCAN + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 - node_ref: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + - type: NODE_FILE_DELETE + - type: NODE_FILE_CORRUPT + - type: NODE_OS_SCAN + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 - node_ref: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null - action_space: - action_list: - - type: DONOTHING - # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + - type: NODE_FILE_DELETE + - type: NODE_FILE_CORRUPT + - type: NODE_OS_SCAN + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender1 + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: NETWORK_ACL_ADDRULE + options: + target_router_ref: router_1 + - type: NETWORK_ACL_REMOVERULE + options: + target_router_ref: router_1 + - type: NETWORK_NIC_ENABLE + - type: NETWORK_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 2 + service_id: 1 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 2 + service_id: 1 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 2 + service_id: 1 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 2 + service_id: 1 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 2 + service_id: 1 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 2 + service_id: 1 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 2 + service_id: 1 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 2 + service_id: 1 + 9: + action: "NODE_FILE_SCAN" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 10: + action: "NODE_FILE_CHECKHASH" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 13: + action: "NODE_FILE_RESTORE" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 3 + folder_id: 1 + 15: + action: "NODE_FOLDER_CHECKHASH" + options: + node_id: 3 + folder_id: 1 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 3 + folder_id: 1 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 3 + folder_id: 1 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 20: + action: "NODE_STARTUP" + options: + node_id: 6 + 21: + action: "NODE_RESET" + options: + node_id: 6 + 22: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 1 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 23: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 1 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 24: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 3 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 25: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 3 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 26: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 4 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 27: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 4 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 28: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 0 + 29: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 1 + 30: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 2 + 31: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 3 + 32: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 4 + 33: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 5 + 34: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 6 + 35: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 7 + 36: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 8 + 37: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 9 + 38: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 1 + nic_id: 1 + 39: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 1 + nic_id: 1 + 40: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 2 + nic_id: 1 + 41: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 2 + nic_id: 1 + 42: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 3 + nic_id: 1 + 43: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 3 + nic_id: 1 + 44: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 45: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 46: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 5 + nic_id: 1 + 47: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 5 + nic_id: 1 + 48: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 5 + nic_id: 2 + 49: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 5 + nic_id: 2 + 50: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 6 + nic_id: 1 + 51: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 6 + nic_id: 1 + 52: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 7 + nic_id: 1 + 53: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 7 + nic_id: 1 + + + options: + nodes: + - node_ref: router_1 + - node_ref: switch_1 + - node_ref: switch_2 + - node_ref: domain_controller + - node_ref: web_server + - node_ref: database_server + - node_ref: backup_server + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.5 + options: + node_ref: database_server + folder_name: database + file_name: database.db + + + - type: WEB_SERVER_404_PENALTY + weight: 0.5 + options: + node_ref: web_server + service_ref: web_server_web_service + + + agent_settings: + # ... + + - ref: defender2 + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: NETWORK_ACL_ADDRULE + options: + target_router_ref: router_1 + - type: NETWORK_ACL_REMOVERULE + options: + target_router_ref: router_1 + - type: NETWORK_NIC_ENABLE + - type: NETWORK_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 2 + service_id: 1 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 2 + service_id: 1 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 2 + service_id: 1 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 2 + service_id: 1 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 2 + service_id: 1 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 2 + service_id: 1 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 2 + service_id: 1 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 2 + service_id: 1 + 9: + action: "NODE_FILE_SCAN" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 10: + action: "NODE_FILE_CHECKHASH" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 13: + action: "NODE_FILE_RESTORE" + options: + node_id: 3 + folder_id: 1 + file_id: 1 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 3 + folder_id: 1 + 15: + action: "NODE_FOLDER_CHECKHASH" + options: + node_id: 3 + folder_id: 1 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 3 + folder_id: 1 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 3 + folder_id: 1 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 20: + action: "NODE_STARTUP" + options: + node_id: 6 + 21: + action: "NODE_RESET" + options: + node_id: 6 + 22: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 1 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 23: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 1 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 24: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 3 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 25: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 3 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 26: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 + dest_ip_id: 4 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 27: + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 8 + dest_ip_id: 4 + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 28: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 0 + 29: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 1 + 30: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 2 + 31: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 3 + 32: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 4 + 33: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 5 + 34: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 6 + 35: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 7 + 36: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 8 + 37: + action: "NETWORK_ACL_REMOVERULE" + options: + position: 9 + 38: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 1 + nic_id: 1 + 39: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 1 + nic_id: 1 + 40: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 2 + nic_id: 1 + 41: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 2 + nic_id: 1 + 42: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 3 + nic_id: 1 + 43: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 3 + nic_id: 1 + 44: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 45: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 46: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 5 + nic_id: 1 + 47: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 5 + nic_id: 1 + 48: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 5 + nic_id: 2 + 49: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 5 + nic_id: 2 + 50: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 6 + nic_id: 1 + 51: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 6 + nic_id: 1 + 52: + action: "NETWORK_NIC_DISABLE" + options: + node_id: 7 + nic_id: 1 + 53: + action: "NETWORK_NIC_ENABLE" + options: + node_id: 7 + nic_id: 1 + + + options: + nodes: + - node_ref: router_1 + - node_ref: switch_1 + - node_ref: switch_2 + - node_ref: domain_controller + - node_ref: web_server + - node_ref: database_server + - node_ref: backup_server + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.5 + options: + node_ref: database_server + folder_name: database + file_name: database.db + + + - type: WEB_SERVER_404_PENALTY + weight: 0.5 + options: + node_ref: web_server + service_ref: web_server_web_service + + + agent_settings: + # ... + + + + + +simulation: + network: + nodes: + + - ref: router_1 + type: router + hostname: router_1 + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + acl: + 0: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 1: + action: PERMIT + src_port: DNS + dst_port: DNS + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - ref: switch_1 + type: switch + hostname: switch_1 + num_ports: 8 + + - ref: switch_2 + type: switch + hostname: switch_2 + num_ports: 8 + + - ref: domain_controller + type: server + hostname: domain_controller + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - ref: domain_controller_dns_server + type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - ref: web_server + type: server + hostname: web_server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.10 + dns_server: 192.168.1.10 + services: + - ref: web_server_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + - ref: web_server_web_service + type: WebServer + + + - ref: database_server + type: server + hostname: database_server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: database_service + type: DatabaseService + + - ref: backup_server + type: server + hostname: backup_server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: backup_service + type: DatabaseBackup + + - ref: security_suite + type: server + hostname: security_suite + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + nics: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - ref: client_1 + type: computer + hostname: client_1 + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.1 + data_manipulation_p_of_success: 0.1 + payload: "DELETE" + server_ip: 192.168.1.14 + services: + - ref: client_1_dns_client + type: DNSClient + + - ref: client_2 + type: computer + hostname: client_2 + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: client_2_web_browser + type: WebBrowser + services: + - ref: client_2_dns_client + type: DNSClient + + links: + - ref: router_1___switch_1 + endpoint_a_ref: router_1 + endpoint_a_port: 1 + endpoint_b_ref: switch_1 + endpoint_b_port: 8 + - ref: router_1___switch_2 + endpoint_a_ref: router_1 + endpoint_a_port: 2 + endpoint_b_ref: switch_2 + endpoint_b_port: 8 + - ref: switch_1___domain_controller + endpoint_a_ref: switch_1 + endpoint_a_port: 1 + endpoint_b_ref: domain_controller + endpoint_b_port: 1 + - ref: switch_1___web_server + endpoint_a_ref: switch_1 + endpoint_a_port: 2 + endpoint_b_ref: web_server + endpoint_b_port: 1 + - ref: switch_1___database_server + endpoint_a_ref: switch_1 + endpoint_a_port: 3 + endpoint_b_ref: database_server + endpoint_b_port: 1 + - ref: switch_1___backup_server + endpoint_a_ref: switch_1 + endpoint_a_port: 4 + endpoint_b_ref: backup_server + endpoint_b_port: 1 + - ref: switch_1___security_suite + endpoint_a_ref: switch_1 + endpoint_a_port: 7 + endpoint_b_ref: security_suite + endpoint_b_port: 1 + - ref: switch_2___client_1 + endpoint_a_ref: switch_2 + endpoint_a_port: 1 + endpoint_b_ref: client_1 + endpoint_b_port: 1 + - ref: switch_2___client_2 + endpoint_a_ref: switch_2 + endpoint_a_port: 2 + endpoint_b_ref: client_2 + endpoint_b_port: 1 + - ref: switch_2___security_suite + endpoint_a_ref: switch_2 + endpoint_a_port: 7 + endpoint_b_ref: security_suite + endpoint_b_port: 2 diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 64be5488..d7e94cb6 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -15,7 +15,7 @@ io_settings: checkpoint_interval: 5 -game_config: +game: ports: - ARP - DNS @@ -26,522 +26,507 @@ game_config: - TCP - UDP - agents: - - ref: client_1_green_user - team: GREEN - type: GreenWebBrowsingAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + - type: NODE_FILE_DELETE + - type: NODE_FILE_CORRUPT + - type: NODE_OS_SCAN + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 - node_ref: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + - type: NODE_FILE_DELETE + - type: NODE_FILE_CORRUPT + - type: NODE_OS_SCAN + options: + nodes: + - node_ref: client_1 + applications: + - application_ref: data_manipulation_bot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_node_ref: router_1 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 - node_ref: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + ics: null - action_space: - action_list: - - type: DONOTHING - # Date: Mon, 27 Nov 2023 11:55:58 +0000 Subject: [PATCH 25/35] Fix incorrect order in session from config --- src/primaite/session/session.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 80b63ba7..3919902a 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -88,16 +88,16 @@ class PrimaiteSession: @classmethod def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + io_manager = SessionIO(SessionIOSettings(**io_settings)) + game = PrimaiteGame.from_config(cfg) sess = cls(game=game) - + sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) - # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": sess.env = PrimaiteRayEnv(env_config={"game": game}) From 89cbc0835221f4172ae469aa9c7229b3ee8b4cb4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 27 Nov 2023 13:28:11 +0000 Subject: [PATCH 26/35] Apply suggestions from code review --- src/primaite/game/agent/actions.py | 54 +++++++++---------- .../game/agent/data_manipulation_agent.py | 0 src/primaite/game/agent/interface.py | 23 +++++--- .../red_services/data_manipulation_bot.py | 6 +-- .../test_data_manipulation_bot.py | 2 +- 5 files changed, 47 insertions(+), 38 deletions(-) delete mode 100644 src/primaite/game/agent/data_manipulation_agent.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 6c6cf7b2..ea992485 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -82,7 +82,7 @@ class NodeServiceAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, service_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -98,7 +98,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "scan" + self.verb: str = "scan" class NodeServiceStopAction(NodeServiceAbstractAction): @@ -106,7 +106,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "stop" + self.verb: str = "stop" class NodeServiceStartAction(NodeServiceAbstractAction): @@ -114,7 +114,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "start" + self.verb: str = "start" class NodeServicePauseAction(NodeServiceAbstractAction): @@ -122,7 +122,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "pause" + self.verb: str = "pause" class NodeServiceResumeAction(NodeServiceAbstractAction): @@ -130,7 +130,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "resume" + self.verb: str = "resume" class NodeServiceRestartAction(NodeServiceAbstractAction): @@ -138,7 +138,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "restart" + self.verb: str = "restart" class NodeServiceDisableAction(NodeServiceAbstractAction): @@ -146,7 +146,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "disable" + self.verb: str = "disable" class NodeServiceEnableAction(NodeServiceAbstractAction): @@ -154,7 +154,7 @@ class NodeServiceEnableAction(NodeServiceAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb = "enable" + self.verb: str = "enable" class NodeApplicationAbstractAction(AbstractAction): @@ -169,7 +169,7 @@ class NodeApplicationAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, application_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -185,7 +185,7 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb = "execute" + self.verb: str = "execute" class NodeFolderAbstractAction(AbstractAction): @@ -200,7 +200,7 @@ class NodeFolderAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, folder_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -254,7 +254,7 @@ class NodeFileAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -271,7 +271,7 @@ class NodeFileScanAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "scan" + self.verb: str = "scan" class NodeFileCheckhashAction(NodeFileAbstractAction): @@ -279,7 +279,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "checkhash" + self.verb: str = "checkhash" class NodeFileDeleteAction(NodeFileAbstractAction): @@ -287,7 +287,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "delete" + self.verb: str = "delete" class NodeFileRepairAction(NodeFileAbstractAction): @@ -295,7 +295,7 @@ class NodeFileRepairAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "repair" + self.verb: str = "repair" class NodeFileRestoreAction(NodeFileAbstractAction): @@ -303,7 +303,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "restore" + self.verb: str = "restore" class NodeFileCorruptAction(NodeFileAbstractAction): @@ -311,7 +311,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb = "corrupt" + self.verb: str = "corrupt" class NodeAbstractAction(AbstractAction): @@ -325,7 +325,7 @@ class NodeAbstractAction(AbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -338,7 +338,7 @@ class NodeOSScanAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "scan" + self.verb: str = "scan" class NodeShutdownAction(NodeAbstractAction): @@ -346,7 +346,7 @@ class NodeShutdownAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "shutdown" + self.verb: str = "shutdown" class NodeStartupAction(NodeAbstractAction): @@ -354,7 +354,7 @@ class NodeStartupAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "startup" + self.verb: str = "startup" class NodeResetAction(NodeAbstractAction): @@ -362,7 +362,7 @@ class NodeResetAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes) - self.verb = "reset" + self.verb: str = "reset" class NetworkACLAddRuleAction(AbstractAction): @@ -520,7 +520,7 @@ class NetworkNICAbstractAction(AbstractAction): """ super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} - self.verb: str + self.verb: str # define but don't initialise: defends against children classes not defining this def form_request(self, node_id: int, nic_id: int) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -543,7 +543,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb = "enable" + self.verb: str = "enable" class NetworkNICDisableAction(NetworkNICAbstractAction): @@ -551,7 +551,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb = "disable" + self.verb: str = "disable" class ActionManager: diff --git a/src/primaite/game/agent/data_manipulation_agent.py b/src/primaite/game/agent/data_manipulation_agent.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 6e783725..fbbe5473 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType -from pydantic import BaseModel +from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationManager @@ -23,6 +23,21 @@ class AgentStartSettings(BaseModel): variance: int = 0 "The amount the frequency can randomly change to" + @model_validator(mode="after") + def check_variance_lt_frequency(self) -> "AgentStartSettings": + """ + Make sure variance is equal to or lower than frequency. + + This is because the calculation for the next execution time is now + (frequency +- variance). If variance were + greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again. + """ + if self.variance > self.frequency: + raise ValueError( + f"Agent start settings error: variance must be lower than frequency " + f"{self.variance=}, {self.frequency=}" + ) + return self + class AgentSettings(BaseModel): """Settings for configuring the operation of an agent.""" @@ -180,9 +195,3 @@ class ProxyAgent(AbstractAgent): The environment is responsible for calling this method when it receives an action from the agent policy. """ self.most_recent_action = action - - -class AbstractGATEAgent(AbstractAgent): - """Base class for actors controlled via external messages, such as RL policies.""" - - ... diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index 6db9e1aa..b0b34396 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -24,7 +24,7 @@ class DataManipulationAttackStage(IntEnum): "Represents the stage of performing a horizontal port scan on the target." ATTACKING = 3 "Stage of actively attacking the target." - COMPLETE = 4 + SUCCEEDED = 4 "Indicates the attack has been successfully completed." FAILED = 5 "Signifies that the attack has failed." @@ -134,7 +134,7 @@ class DataManipulationBot(DatabaseClient): attack_successful = True if attack_successful: self.sys_log.info(f"{self.name}: Data manipulation successful") - self.attack_stage = DataManipulationAttackStage.COMPLETE + self.attack_stage = DataManipulationAttackStage.SUCCEEDED else: self.sys_log.info(f"{self.name}: Data manipulation failed") self.attack_stage = DataManipulationAttackStage.FAILED @@ -163,7 +163,7 @@ class DataManipulationBot(DatabaseClient): self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success) if self.repeat and self.attack_stage in ( - DataManipulationAttackStage.COMPLETE, + DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED, ): self.attack_stage = DataManipulationAttackStage.NOT_STARTED diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index 936f7c5c..3b1e4aa4 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -69,5 +69,5 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot._perform_data_manipulation(p_of_success=1.0) - assert dm_bot.attack_stage in (DataManipulationAttackStage.COMPLETE, DataManipulationAttackStage.FAILED) + assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) assert dm_bot.connected From 4d4a578555f2452c85faa72e2b40b85cd4489542 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 27 Nov 2023 13:47:59 +0000 Subject: [PATCH 27/35] #1859 - Integrated the runtime execution for web client. Added in the webclient application execution action. Now fixing http status code issues. --- .../config/_package_data/example_config.yaml | 28 +++++++++---------- src/primaite/game/game.py | 5 +++- src/primaite/session/environment.py | 2 +- src/primaite/simulator/network/networks.py | 2 ++ .../simulator/network/protocols/http.py | 4 +-- .../system/applications/database_client.py | 4 +-- .../system/applications/web_browser.py | 15 ++++++++-- .../system/services/web_server/web_server.py | 4 ++- .../system/test_web_client_server.py | 11 ++++---- 9 files changed, 46 insertions(+), 29 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 3cea2f29..b68861e1 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -1,5 +1,5 @@ training_config: - rl_framework: RLLIB_single_agent + rl_framework: SB3 rl_algorithm: PPO seed: 333 n_learn_episodes: 1 @@ -36,22 +36,16 @@ agents: action_space: action_list: - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com - + - type: NODE_APPLICATION_EXECUTE options: nodes: - node_ref: client_2 + applications: + - application_ref: client_2_web_browser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + max_applications_per_node: 1 reward_function: reward_components: @@ -549,19 +543,19 @@ simulation: ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: - 0: + 18: action: PERMIT src_port: POSTGRES_SERVER dst_port: POSTGRES_SERVER - 1: + 19: action: PERMIT src_port: DNS dst_port: DNS - 2: + 20: action: PERMIT src_port: FTP dst_port: FTP - 3: + 21: action: PERMIT src_port: HTTP dst_port: HTTP @@ -679,10 +673,14 @@ simulation: applications: - ref: client_2_web_browser type: WebBrowser + options: + target_url: http://arcd.com/users/ services: - ref: client_2_dns_client type: DNSClient + + links: - ref: router_1___switch_1 endpoint_a_ref: router_1 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ae60bbc1..48615ca6 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -316,6 +316,10 @@ class PrimaiteGame: port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), ) + elif application_type == "WebBrowser": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.target_url = opt.get("target_url") if "nics" in node_cfg: for nic_num, nic_cfg in node_cfg["nics"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) @@ -377,7 +381,6 @@ class PrimaiteGame: action_space_cfg["options"]["application_uuids"].append(node_application_uuids) else: action_space_cfg["options"]["application_uuids"].append([]) - # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. # However, it's not possible to specify the node uuids directly in the config, as they are generated diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index db24db60..a5fdade9 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -37,7 +37,7 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = {} - + print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}") return next_obs, reward, terminated, truncated, info def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index ea767b54..446e5649 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -157,6 +157,8 @@ def arcd_uc2_network() -> Network: operating_state=NodeOperatingState.ON, ) client_2.power_on() + web_browser = client_2.software_manager["WebBrowser"] + web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) # Domain Controller diff --git a/src/primaite/simulator/network/protocols/http.py b/src/primaite/simulator/network/protocols/http.py index 2dba2614..b88916a9 100644 --- a/src/primaite/simulator/network/protocols/http.py +++ b/src/primaite/simulator/network/protocols/http.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, IntEnum from primaite.simulator.network.protocols.packet import DataPacket @@ -25,7 +25,7 @@ class HttpRequestMethod(Enum): """Apply partial modifications to a resource.""" -class HttpStatusCode(Enum): +class HttpStatusCode(IntEnum): """List of available HTTP Statuses.""" OK = 200 diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index b24b6062..37236e69 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -75,11 +75,11 @@ class DatabaseClient(Application): """ if is_reattempt: if self.connected: - self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} authorised") + self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised") self.server_ip_address = server_ip_address return self.connected else: - self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} declined") + self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined") return False payload = {"type": "connect_request", "password": password} software_manager: SoftwareManager = self.software_manager diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index ea9c3ac3..0a9c7fc3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -2,6 +2,7 @@ from ipaddress import IPv4Address from typing import Dict, Optional from urllib.parse import urlparse +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.http import HttpRequestMethod, HttpRequestPacket, HttpResponsePacket from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -16,6 +17,8 @@ class WebBrowser(Application): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ + target_url: Optional[str] = None + domain_name_ip_address: Optional[IPv4Address] = None "The IP address of the domain name for the webpage." @@ -32,6 +35,14 @@ class WebBrowser(Application): super().__init__(**kwargs) self.run() + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + rm.add_request( + name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa + ) + + return rm + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of the WebBrowser. @@ -51,7 +62,7 @@ class WebBrowser(Application): self.domain_name_ip_address = None self.latest_response = None - def get_webpage(self, url: str) -> bool: + def get_webpage(self) -> bool: """ Retrieve the webpage. @@ -60,6 +71,7 @@ class WebBrowser(Application): :param: url: The address of the web page the browser requests :type: url: str """ + url = self.target_url # reset latest response self.latest_response = None @@ -71,7 +83,6 @@ class WebBrowser(Application): # get the IP address of the domain name via DNS dns_client: DNSClient = self.software_manager.software["DNSClient"] - domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) # if domain does not exist, the request fails diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index cb1a4738..5dda82d5 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -29,8 +29,9 @@ class WebServer(Service): :rtype: Dict """ state = super().describe_state() + state["last_response_status_code"] = ( - self.last_response_status_code.value if self.last_response_status_code else None + self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) return state @@ -84,6 +85,7 @@ class WebServer(Service): # return true if response is OK self.last_response_status_code = response.status_code + print(self.last_response_status_code) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index f4546cbf..991d6282 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -3,7 +3,6 @@ from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.protocols.http import HttpStatusCode from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.simulator.system.services.service import ServiceOperatingState def test_web_page_home_page(uc2_network): @@ -11,9 +10,10 @@ def test_web_page_home_page(uc2_network): client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] web_client.run() + web_client.target_url = "http://arcd.com/" assert web_client.operating_state == ApplicationOperatingState.RUNNING - assert web_client.get_webpage("http://arcd.com/") is True + assert web_client.get_webpage() is True # latest reponse should have status code 200 assert web_client.latest_response is not None @@ -27,7 +27,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network): web_client.run() assert web_client.operating_state == ApplicationOperatingState.RUNNING - assert web_client.get_webpage("http://arcd.com/users/") is True + assert web_client.get_webpage() is True # latest reponse should have status code 200 assert web_client.latest_response is not None @@ -41,11 +41,12 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network): web_client.run() web_server: Server = uc2_network.get_node_by_hostname("web_server") - web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address + web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address + web_client.target_url = f"http://{web_server_ip}/users/" assert web_client.operating_state == ApplicationOperatingState.RUNNING - assert web_client.get_webpage(f"http://{web_server_ip}/users/") is True + assert web_client.get_webpage() is True # latest reponse should have status code 200 assert web_client.latest_response is not None From ae5046b8fb94d1a8c787f870fa461489c5d98fef Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 27 Nov 2023 17:05:12 +0000 Subject: [PATCH 28/35] #1859 - As disccused --- src/primaite/game/agent/actions.py | 11 ++-- src/primaite/game/agent/observations.py | 2 +- src/primaite/game/agent/rewards.py | 3 +- src/primaite/game/game.py | 56 +++++++++++++------ .../simulator/network/hardware/base.py | 3 + .../system/applications/web_browser.py | 7 +++ .../system/services/web_server/web_server.py | 29 +++++++++- 7 files changed, 83 insertions(+), 28 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index ea992485..62e56c6e 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -634,7 +634,6 @@ class ActionManager: :type act_map: Optional[Dict[int, Dict]] """ self.game: "PrimaiteGame" = game - self.sim: Simulation = self.game.simulation self.node_uuids: List[str] = node_uuids self.application_uuids: List[List[str]] = application_uuids self.protocols: List[str] = protocols @@ -646,7 +645,7 @@ class ActionManager: else: self.ip_address_list = [] for node_uuid in self.node_uuids: - node_obj = self.sim.network.nodes[node_uuid] + node_obj = self.game.simulation.network.nodes[node_uuid] nics = node_obj.nics for nic_uuid, nic_obj in nics.items(): self.ip_address_list.append(nic_obj.ip_address) @@ -770,7 +769,7 @@ class ActionManager: :rtype: Optional[str] """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] + node = self.game.simulation.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None @@ -788,7 +787,7 @@ class ActionManager: :rtype: Optional[str] """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] + node = self.game.simulation.network.nodes[node_uuid] folder_uuids = list(node.file_system.folders.keys()) if len(folder_uuids) <= folder_idx: return None @@ -807,7 +806,7 @@ class ActionManager: :rtype: Optional[str] """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node = self.sim.network.nodes[node_uuid] + node = self.game.simulation.network.nodes[node_uuid] service_uuids = list(node.services.keys()) return service_uuids[service_idx] if len(service_uuids) > service_idx else None @@ -867,7 +866,7 @@ class ActionManager: :rtype: str """ node_uuid = self.get_node_uuid_by_idx(node_idx) - node_obj = self.sim.network.nodes[node_uuid] + node_obj = self.game.simulation.network.nodes[node_uuid] nics = list(node_obj.nics.keys()) if len(nics) <= nic_idx: return None diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 14fb2fa7..823d65d7 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -162,7 +162,7 @@ class ServiceObservation(AbstractObservation): :return: Constructed service observation :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid]) + return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]]) class LinkObservation(AbstractObservation): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 8a1c2da4..7cca9116 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -25,6 +25,7 @@ the structure: service_ref: web_server_database_client ``` """ +import json from abc import abstractmethod from typing import Dict, List, Tuple, Type, TYPE_CHECKING @@ -213,7 +214,7 @@ class WebServer404Penalty(AbstractReward): _LOGGER.warn(msg) return DummyReward() # TODO: should we error out with incorrect inputs? Probably! node_uuid = game.ref_map_nodes[node_ref] - service_uuid = game.ref_map_services[service_ref].uuid + service_uuid = game.ref_map_services[service_ref] if not (node_uuid and service_uuid): msg = ( f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 48615ca6..147ed499 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -59,8 +59,9 @@ class PrimaiteGame: """Initialise a PrimaiteGame object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + print(f"Hello, welcome to PrimaiteGame. This is the ID of the ORIGINAL simulation {id(self.simulation)}") - self._simulation_initial_state = deepcopy(self.simulation) + self._simulation_initial_state = None """The Simulation original state (deepcopy of the original Simulation).""" self.agents: List[AbstractAgent] = [] @@ -78,16 +79,16 @@ class PrimaiteGame: self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" - self.ref_map_nodes: Dict[str, Node] = {} + self.ref_map_nodes: Dict[str, str] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" - self.ref_map_services: Dict[str, Service] = {} + self.ref_map_services: Dict[str, str] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" - self.ref_map_applications: Dict[str, Application] = {} + self.ref_map_applications: Dict[str, str] = {} """Mapping from human-readable application reference to application object. Used for parsing config files.""" - self.ref_map_links: Dict[str, Link] = {} + self.ref_map_links: Dict[str, str] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" def step(self): @@ -161,6 +162,33 @@ class PrimaiteGame: self.step_counter = 0 _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") self.simulation = deepcopy(self._simulation_initial_state) + self._reset_components_for_episode() + print("Reset") + + def _reset_components_for_episode(self): + print("Performing full reset for episode") + for node in self.simulation.network.nodes.values(): + print(f"Resetting Node: {node.hostname}") + node.reset_component_for_episode(self.episode_counter) + + # reset Node NIC + + # Reset Node Services + + # Reset Node Applications + print(f"Resetting Software...") + for application in node.software_manager.software.values(): + print(f"Resetting {application.name}") + if isinstance(application, WebBrowser): + application.do_this() + + # Reset Node FileSystem + # Reset Node FileSystemFolder's + # Reset Node FileSystemFile's + + # Reset Router + + # Reset Links def close(self) -> None: """Close the game, this will close the simulation.""" @@ -190,10 +218,6 @@ class PrimaiteGame: sim = game.simulation net = sim.network - game.ref_map_nodes: Dict[str, Node] = {} - game.ref_map_services: Dict[str, Service] = {} - game.ref_map_links: Dict[str, Link] = {} - nodes_cfg = cfg["simulation"]["network"]["nodes"] links_cfg = cfg["simulation"]["network"]["links"] for node_cfg in nodes_cfg: @@ -269,7 +293,7 @@ class PrimaiteGame: print(f"installing {service_type} on node {new_node.hostname}") new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] - game.ref_map_services[service_ref] = new_service + game.ref_map_services[service_ref] = new_service.uuid else: print(f"service type not found {service_type}") # service-dependent options @@ -303,7 +327,7 @@ class PrimaiteGame: if application_type in application_types_mapping: new_node.software_manager.install(application_types_mapping[application_type]) new_application = new_node.software_manager.software[application_type] - game.ref_map_applications[application_ref] = new_application + game.ref_map_applications[application_ref] = new_application.uuid else: print(f"application type not found {application_type}") @@ -326,11 +350,7 @@ class PrimaiteGame: net.add_node(new_node) new_node.power_on() - game.ref_map_nodes[ - node_ref - ] = ( - new_node.uuid - ) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object + game.ref_map_nodes[node_ref] = new_node.uuid # 2. create links between nodes for link_cfg in links_cfg: @@ -375,7 +395,7 @@ class PrimaiteGame: for application_option in action_node_option["applications"]: # TODO: fix inconsistency with node uuids and application uuids. The node object get added to # node_uuid, whereas here the application gets added by uuid. - application_uuid = game.ref_map_applications[application_option["application_ref"]].uuid + application_uuid = game.ref_map_applications[application_option["application_ref"]] node_application_uuids.append(application_uuid) action_space_cfg["options"]["application_uuids"].append(node_application_uuids) @@ -433,5 +453,7 @@ class PrimaiteGame: print("agent type not found") game._simulation_initial_state = deepcopy(game.simulation) # noqa + web_server = game.simulation.network.get_node_by_hostname("web_server").software_manager.software["WebServer"] + print(f"And this is the ID of the original WebServer {id(web_server)}") return game diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 29d3a05c..0717f813 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1005,6 +1005,9 @@ class Node(SimComponent): return rm + def reset_component_for_episode(self, episode: int): + self._init_request_manager() + def _install_system_software(self): """Install System Software - software that is usually provided with the OS.""" pass diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 0a9c7fc3..ef9ac0e7 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -43,6 +43,13 @@ class WebBrowser(Application): return rm + def do_this(self): + self._init_request_manager() + print(f"Resetting WebBrowser for episode") + + def reset_component_for_episode(self, episode: int): + pass + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of the WebBrowser. diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5dda82d5..86a4e4f1 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -17,7 +17,16 @@ from primaite.simulator.system.services.service import Service class WebServer(Service): """Class used to represent a Web Server Service in simulation.""" - last_response_status_code: Optional[HttpStatusCode] = None + _last_response_status_code: Optional[HttpStatusCode] = None + + @property + def last_response_status_code(self) -> HttpStatusCode: + return self._last_response_status_code + + @last_response_status_code.setter + def last_response_status_code(self, val: Any): + print(f"val: {val}, type: {type(val)}") + self._last_response_status_code = val def describe_state(self) -> Dict: """ @@ -29,10 +38,17 @@ class WebServer(Service): :rtype: Dict """ state = super().describe_state() - state["last_response_status_code"] = ( self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) + + print( + f"" + f"Printing state from Webserver describe func: " + f"val={state['last_response_status_code']}, " + f"type={type(state['last_response_status_code'])}, " + f"Service obj ID={id(self)}" + ) return state def __init__(self, **kwargs): @@ -85,7 +101,14 @@ class WebServer(Service): # return true if response is OK self.last_response_status_code = response.status_code - print(self.last_response_status_code) + + print( + f"" + f"Printing state from Webserver http request func: " + f"val={self.last_response_status_code}, " + f"type={type(self.last_response_status_code)}, " + f"Service obj ID={id(self)}" + ) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: From 58e9033a4c8729290e56d9d2b601ab291521b65c Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 27 Nov 2023 23:01:56 +0000 Subject: [PATCH 29/35] #1859 - First pass at an implementation of the full reset method. Will now start testing... --- src/primaite/game/agent/actions.py | 1 - src/primaite/game/agent/rewards.py | 1 - src/primaite/game/game.py | 42 +--- src/primaite/simulator/core.py | 20 +- src/primaite/simulator/domain/account.py | 13 ++ src/primaite/simulator/file_system/file.py | 12 ++ .../simulator/file_system/file_system.py | 30 +++ .../file_system/file_system_item_abc.py | 5 + src/primaite/simulator/file_system/folder.py | 38 ++++ src/primaite/simulator/network/container.py | 14 ++ .../simulator/network/hardware/base.py | 195 ++++++++---------- .../network/hardware/nodes/router.py | 31 +++ src/primaite/simulator/sim_container.py | 10 +- .../system/applications/application.py | 15 +- .../system/applications/database_client.py | 7 + .../system/applications/web_browser.py | 23 +-- .../simulator/system/core/packet_capture.py | 9 +- .../simulator/system/core/session_manager.py | 5 + src/primaite/simulator/system/core/sys_log.py | 7 +- .../simulator/system/processes/process.py | 6 + .../services/database/database_service.py | 17 ++ .../system/services/dns/dns_client.py | 20 +- .../system/services/dns/dns_server.py | 14 +- .../simulator/system/services/service.py | 15 +- .../system/services/web_server/web_server.py | 21 +- src/primaite/simulator/system/software.py | 29 ++- 26 files changed, 360 insertions(+), 240 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 62e56c6e..c70d4d66 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -15,7 +15,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces from primaite import getLogger -from primaite.simulator.sim_container import Simulation _LOGGER = getLogger(__name__) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 7cca9116..3466114c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -25,7 +25,6 @@ the structure: service_ref: web_server_database_client ``` """ -import json from abc import abstractmethod from typing import Dict, List, Tuple, Type, TYPE_CHECKING diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 147ed499..38e9d5fc 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,5 +1,4 @@ """PrimAITE game - Encapsulates the simulation and agents.""" -from copy import deepcopy from ipaddress import IPv4Address from typing import Dict, List @@ -11,7 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState +from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.server import Server @@ -19,7 +18,6 @@ from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -28,7 +26,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) @@ -59,10 +56,6 @@ class PrimaiteGame: """Initialise a PrimaiteGame object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - print(f"Hello, welcome to PrimaiteGame. This is the ID of the ORIGINAL simulation {id(self.simulation)}") - - self._simulation_initial_state = None - """The Simulation original state (deepcopy of the original Simulation).""" self.agents: List[AbstractAgent] = [] """List of agents.""" @@ -161,34 +154,7 @@ class PrimaiteGame: self.episode_counter += 1 self.step_counter = 0 _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation = deepcopy(self._simulation_initial_state) - self._reset_components_for_episode() - print("Reset") - - def _reset_components_for_episode(self): - print("Performing full reset for episode") - for node in self.simulation.network.nodes.values(): - print(f"Resetting Node: {node.hostname}") - node.reset_component_for_episode(self.episode_counter) - - # reset Node NIC - - # Reset Node Services - - # Reset Node Applications - print(f"Resetting Software...") - for application in node.software_manager.software.values(): - print(f"Resetting {application.name}") - if isinstance(application, WebBrowser): - application.do_this() - - # Reset Node FileSystem - # Reset Node FileSystemFolder's - # Reset Node FileSystemFile's - - # Reset Router - - # Reset Links + self.simulation.reset_component_for_episode(episode=self.episode_counter) def close(self) -> None: """Close the game, this will close the simulation.""" @@ -452,8 +418,6 @@ class PrimaiteGame: else: print("agent type not found") - game._simulation_initial_state = deepcopy(game.simulation) # noqa - web_server = game.simulation.network.get_node_by_hostname("web_server").software_manager.software["WebServer"] - print(f"And this is the ID of the original WebServer {id(web_server)}") + game.simulation.set_original_state() return game diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 9ead877e..18a470cd 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,6 +153,8 @@ class SimComponent(BaseModel): uuid: str """The component UUID.""" + _original_state: Dict = {} + def __init__(self, **kwargs): if not kwargs.get("uuid"): kwargs["uuid"] = str(uuid4()) @@ -160,6 +162,16 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None + # @abstractmethod + def set_original_state(self): + """Sets the original state.""" + pass + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + for key, value in self._original_state.items(): + self.__setattr__(key, value) + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager for this component. @@ -227,14 +239,6 @@ class SimComponent(BaseModel): """ pass - def reset_component_for_episode(self, episode: int): - """ - Reset this component to its original state for a new episode. - - Override this method with anything that needs to happen within the component for it to be reset. - """ - pass - @property def parent(self) -> "SimComponent": """Reference to the parent object which manages this object. diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d235c00e..1402a474 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,6 +42,19 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True + def set_original_state(self): + """Sets the original state.""" + vals_to_include = { + "num_logons", + "num_logoffs", + "num_group_changes", + "username", + "password", + "account_type", + "enabled", + } + self._original_state = self.model_dump(include=vals_to_include) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index d9b02e8e..8f0abb3c 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,6 +73,18 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().reset_component_for_episode(episode) + @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 41f02270..dc6f01a3 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -35,6 +35,36 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") + def set_original_state(self): + """Sets the original state.""" + for folder in self.folders.values(): + folder.set_original_state() + super().set_original_state() + # Capture a list of all 'original' file uuids + self._original_state["original_folder_uuids"] = list(self.folders.keys()) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Move any 'original' folder that have been deleted back to folders + original_folder_uuids = self._original_state.pop("original_folder_uuids") + for uuid in original_folder_uuids: + if uuid in self.deleted_folders: + self.folders[uuid] = self.deleted_folders.pop(uuid) + + # Clear any other deleted folders that aren't original (have been created by agent) + self.deleted_folders.clear() + + # Now clear all non-original folders created by agent + current_folder_uuids = list(self.folders.keys()) + for uuid in current_folder_uuids: + if uuid not in original_folder_uuids: + self.folders.pop(uuid) + + # Now reset all remaining folders + for folder in self.folders.values(): + folder.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index fbe5f4b3..86cd1ee7 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." + def set_original_state(self): + """Sets the original state.""" + vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"} + self._original_state = self.model_dump(include=vals_to_keep) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index f0d55ef8..8e577097 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -51,6 +51,44 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") + def set_original_state(self): + """Sets the original state.""" + for file in self.files.values(): + file.set_original_state() + super().set_original_state() + vals_to_include = { + "scan_duration", + "scan_countdown", + "red_scan_duration", + "red_scan_countdown", + "restore_duration", + "restore_countdown", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + self._original_state["original_file_uuids"] = list(self.files.keys()) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Move any 'original' file that have been deleted back to files + original_file_uuids = self._original_state.pop("original_file_uuids") + for uuid in original_file_uuids: + if uuid in self.deleted_files: + self.files[uuid] = self.deleted_files.pop(uuid) + + # Clear any other deleted files that aren't original (have been created by agent) + self.deleted_files.clear() + + # Now clear all non-original files created by agent + current_file_uuids = list(self.files.keys()) + for uuid in current_file_uuids: + if uuid not in original_file_uuids: + self.files.pop(uuid) + + # Now reset all remaining files + for file in self.files.values(): + file.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 9fbafc29..cab983c7 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -43,6 +43,20 @@ class Network(SimComponent): self._nx_graph = MultiGraph() + def set_original_state(self): + """Sets the original state.""" + for node in self.nodes.values(): + node.set_original_state() + for link in self.links.values(): + link.set_original_state() + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + for node in self.nodes.values(): + node.reset_component_for_episode(episode) + for link in self.links.values(): + link.reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() self._node_request_manager = RequestManager() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 0717f813..2863dd22 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -121,6 +121,20 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().reset_component_for_episode(episode) + if episode and self.pcap: + self.pcap.current_episode = episode + self.pcap.setup_logger() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -308,6 +322,14 @@ class SwitchPort(SimComponent): kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} + self._original_state = self.model_dump(include=vals_to_include) + super().set_original_state() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -454,6 +476,14 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"bandwidth", "current_load"} + self._original_state = self.model_dump(include=vals_to_include) + super().set_original_state() + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -536,15 +566,6 @@ class Link(SimComponent): return True return False - def reset_component_for_episode(self, episode: int): - """ - Link reset function. - - Reset: - - returns the link current_load to 0. - """ - self.current_load = 0 - def __str__(self) -> str: return f"{self.endpoint_a}<-->{self.endpoint_b}" @@ -584,6 +605,10 @@ class ARPCache: ) print(table) + def clear(self): + """Clears the arp cache.""" + self.arp.clear() + def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False): """ Add an ARP entry to the cache. @@ -756,6 +781,10 @@ class ICMP: self.arp: ARPCache = arp_cache self.request_replies = {} + def clear(self): + """Clears the ICMP request replies tracker.""" + self.request_replies.clear() + def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False): """ Process an ICMP packet, including handling echo requests and replies. @@ -972,6 +1001,55 @@ class Node(SimComponent): self.arp.nics = self.nics self.session_manager.software_manager = self.software_manager self._install_system_software() + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + for software in self.software_manager.software.values(): + software.set_original_state() + + for nic in self.nics.values(): + nic.set_original_state() + + vals_to_include = { + "hostname", + "default_gateway", + "operating_state", + "revealed_to_red", + "start_up_duration", + "start_up_countdown", + "shut_down_duration", + "shut_down_countdown", + "is_resetting", + "node_scan_duration", + "node_scan_countdown", + "red_scan_countdown", + } + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + # Reset ARP Cache + self.arp.clear() + + # Reset ICMP + self.icmp.clear() + + # Reset Session Manager + self.session_manager.clear() + + for software in self.software_manager.software.values(): + software.reset_component_for_episode(episode) + + # Reset all Nics + for nic in self.nics.values(): + nic.reset_component_for_episode(episode) + + if episode and self.sys_log: + self.sys_log.current_episode = episode + self.sys_log.setup_logger() + + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will @@ -1005,9 +1083,6 @@ class Node(SimComponent): return rm - def reset_component_for_episode(self, episode: int): - self._init_request_manager() - def _install_system_software(self): """Install System Software - software that is usually provided with the OS.""" pass @@ -1425,99 +1500,3 @@ class Node(SimComponent): if isinstance(item, Service): return item.uuid in self.services return None - - -class Switch(Node): - """A class representing a Layer 2 network switch.""" - - num_ports: int = 24 - "The number of ports on the switch." - switch_ports: Dict[int, SwitchPort] = {} - "The SwitchPorts on the switch." - mac_address_table: Dict[str, SwitchPort] = {} - "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - - def __init__(self, **kwargs): - super().__init__(**kwargs) - if not self.switch_ports: - self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)} - for port_num, port in self.switch_ports.items(): - port._connected_node = self - port.parent = self - port.port_num = port_num - - def show(self): - """Prints a table of the SwitchPorts on the Switch.""" - table = PrettyTable(["Port", "MAC Address", "Speed", "Status"]) - - for port_num, port in self.switch_ports.items(): - table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) - print(table) - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - return { - "uuid": self.uuid, - "num_ports": self.num_ports, # redundant? - "ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}, - "mac_address_table": {mac: port for mac, port in self.mac_address_table.items()}, - } - - def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): - mac_table_port = self.mac_address_table.get(mac_address) - if not mac_table_port: - self.mac_address_table[mac_address] = switch_port - self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}") - else: - if mac_table_port != switch_port: - self.mac_address_table.pop(mac_address) - self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}") - self._add_mac_table_entry(mac_address, switch_port) - - def forward_frame(self, frame: Frame, incoming_port: SwitchPort): - """ - Forward a frame to the appropriate port based on the destination MAC address. - - :param frame: The Frame to be forwarded. - :param incoming_port: The port number from which the frame was received. - """ - src_mac = frame.ethernet.src_mac_addr - dst_mac = frame.ethernet.dst_mac_addr - self._add_mac_table_entry(src_mac, incoming_port) - - outgoing_port = self.mac_address_table.get(dst_mac) - if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff": - outgoing_port.send_frame(frame) - else: - # If the destination MAC is not in the table, flood to all ports except incoming - for port in self.switch_ports.values(): - if port != incoming_port: - port.send_frame(frame) - - def disconnect_link_from_port(self, link: Link, port_number: int): - """ - Disconnect a given link from the specified port number on the switch. - - :param link: The Link object to be disconnected. - :param port_number: The port number on the switch from where the link should be disconnected. - :raise NetworkError: When an invalid port number is provided or the link does not match the connection. - """ - port = self.switch_ports.get(port_number) - if port is None: - msg = f"Invalid port number {port_number} on the switch" - _LOGGER.error(msg) - raise NetworkError(msg) - - if port._connected_link != link: - msg = f"The link does not match the connection at port number {port_number}" - _LOGGER.error(msg) - raise NetworkError(msg) - - port.disconnect_link() diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index c2a38aba..8e03cfa3 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -52,6 +52,11 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) + def set_original_state(self): + """Sets the original state.""" + vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} + self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -93,6 +98,18 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + self.implicit_rule.set_original_state() + vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} + self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.implicit_rule.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -638,6 +655,20 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + self.acl.set_original_state() + vals_to_include = {"num_ports", "route_table"} + self._original_state = self.model_dump(include=vals_to_include) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.arp.clear() + self.acl.reset_component_for_episode(episode) + super().reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("acl", RequestType(func=self.acl._request_manager)) diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 8e820ec8..c529ed04 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -9,7 +9,7 @@ class Simulation(SimComponent): """Top-level simulation object which holds a reference to all other parts of the simulation.""" network: Network - domain: DomainController + # domain: DomainController def __init__(self, **kwargs): """Initialise the Simulation.""" @@ -21,6 +21,14 @@ class Simulation(SimComponent): super().__init__(**kwargs) + def set_original_state(self): + """Sets the original state.""" + self.network.set_original_state() + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.network.reset_component_for_episode(episode) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # pass through network requests to the network objects diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 9a58c98a..c69f745d 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,6 +38,12 @@ class Application(IOSoftware): self.health_state_visible = SoftwareHealthState.UNUSED self.health_state_actual = SoftwareHealthState.UNUSED + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ @@ -82,15 +88,6 @@ class Application(IOSoftware): self.sys_log.info(f"Installing Application {self.name}") self.operating_state = ApplicationOperatingState.INSTALLING - def reset_component_for_episode(self, episode: int): - """ - Resets the Application component for a new episode. - - This method ensures the Application is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 37236e69..12dfc0ac 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -31,6 +31,13 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + self.set_original_state() + + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"server_ip_address", "server_password", "connected"} + self._original_state.update(self.model_dump(include=vals_to_include)) def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index ef9ac0e7..32dd9cd2 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -33,8 +33,15 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) + self.set_original_state() self.run() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} + self._original_state.update(self.model_dump(include=vals_to_include)) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -43,13 +50,6 @@ class WebBrowser(Application): return rm - def do_this(self): - self._init_request_manager() - print(f"Resetting WebBrowser for episode") - - def reset_component_for_episode(self, episode: int): - pass - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of the WebBrowser. @@ -60,14 +60,7 @@ class WebBrowser(Application): state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None def reset_component_for_episode(self, episode: int): - """ - Resets the Application component for a new episode. - - This method ensures the Application is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - self.domain_name_ip_address = None - self.latest_response = None + """Reset the original state of the SimComponent.""" def get_webpage(self) -> bool: """ diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index c2faeb10..1539e024 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -34,9 +34,12 @@ class PacketCapture: "The IP address associated with the PCAP logs." self.switch_port_number = switch_port_number "The SwitchPort number." - self._setup_logger() - def _setup_logger(self): + self.current_episode: int = 1 + + self.setup_logger() + + def setup_logger(self): """Set up the logger configuration.""" log_path = self._get_log_path() @@ -75,7 +78,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT.path / self.hostname + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 360b5e73..8658f155 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -93,6 +93,11 @@ class SessionManager: """ pass + def clear(self): + """Clears the sessions.""" + self.sessions_by_key.clear() + self.sessions_by_uuid.clear() + @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 7ac6df85..41ce8fee 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -31,9 +31,10 @@ class SysLog: :param hostname: The hostname associated with the system logs being recorded. """ self.hostname = hostname - self._setup_logger() + self.current_episode: int = 1 + self.setup_logger() - def _setup_logger(self): + def setup_logger(self): """ Configures the logger for this SysLog instance. @@ -80,7 +81,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT.path / self.hostname + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index c4e94845..ad9af335 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,6 +24,12 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index d7277e1e..616cbedd 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -38,6 +38,23 @@ class DatabaseService(Service): self._db_file: File self._create_db_file() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = { + "password", + "connections", + "backup_server", + "latest_backup_directory", + "latest_backup_file_name", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.connections.clear() + super().reset_component_for_episode(episode) + def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 266ac4f6..c6c3e09a 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,6 +29,17 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"dns_server"} + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.dns_cache.clear() + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -42,15 +53,6 @@ class DNSClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address): """ Adds a domain name to the DNS Client cache. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 90a350c8..bbeaa62c 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,6 +28,11 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.dns_table.clear() + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -62,15 +67,6 @@ class DNSServer(Service): """ self.dns_table[domain_name] = domain_ip_address - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def receive( self, payload: Any, diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e2b04c15..d519da8e 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -46,6 +46,12 @@ class Service(IOSoftware): self.health_state_visible = SoftwareHealthState.UNUSED self.health_state_actual = SoftwareHealthState.UNUSED + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} + self._original_state.update(self.model_dump(include=vals_to_include)) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) @@ -73,15 +79,6 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 86a4e4f1..754aa22f 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -19,8 +19,14 @@ class WebServer(Service): _last_response_status_code: Optional[HttpStatusCode] = None + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self._last_response_status_code = None + super().reset_component_for_episode(episode) + @property def last_response_status_code(self) -> HttpStatusCode: + """The latest http response code.""" return self._last_response_status_code @last_response_status_code.setter @@ -41,14 +47,6 @@ class WebServer(Service): state["last_response_status_code"] = ( self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) - - print( - f"" - f"Printing state from Webserver describe func: " - f"val={state['last_response_status_code']}, " - f"type={type(state['last_response_status_code'])}, " - f"Service obj ID={id(self)}" - ) return state def __init__(self, **kwargs): @@ -102,13 +100,6 @@ class WebServer(Service): # return true if response is OK self.last_response_status_code = response.status_code - print( - f"" - f"Printing state from Webserver http request func: " - f"val={self.last_response_status_code}, " - f"type={type(self.last_response_status_code)}, " - f"Service obj ID={id(self)}" - ) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index f2627557..413da959 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -89,6 +89,19 @@ class Software(SimComponent): folder: Optional[Folder] = None "The folder on the file system the Software uses." + def set_original_state(self): + """Sets the original state.""" + vals_to_include = { + "name", + "health_state_actual", + "health_state_visible", + "criticality", + "patching_count", + "scanning_count", + "revealed_to_red", + } + self._original_state = self.model_dump(include=vals_to_include) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -131,16 +144,6 @@ class Software(SimComponent): ) return state - def reset_component_for_episode(self, episode: int): - """ - Resets the software component for a new episode. - - This method should ensure the software is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a - "reset" should be implemented in subclasses. - """ - pass - def set_health_state(self, health_state: SoftwareHealthState) -> None: """ Assign a new health state to this software. @@ -203,6 +206,12 @@ class IOSoftware(Software): port: Port "The port to which the software is connected." + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} + self._original_state.update(self.model_dump(include=vals_to_include)) + @abstractmethod def describe_state(self) -> Dict: """ From 39dfbb741f53fbd43c25e43cd7c5dcbad153c29e Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 28 Nov 2023 00:21:41 +0000 Subject: [PATCH 30/35] #1859 - Made some fixes to resets. Still an issue with the Router reset. --- src/primaite/simulator/network/hardware/base.py | 1 + .../simulator/network/hardware/nodes/router.py | 4 ++++ .../simulator/system/services/dns/dns_server.py | 11 +++++++++++ .../system/services/web_server/web_server.py | 3 +-- 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 2863dd22..09e2b12f 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -134,6 +134,7 @@ class NIC(SimComponent): if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() + self.enable() def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 8e03cfa3..1bf2ea2f 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -667,6 +667,10 @@ class Router(Node): """Reset the original state of the SimComponent.""" self.arp.clear() self.acl.reset_component_for_episode(episode) + for i, nic in self.ethernet_ports.items(): + nic.reset_component_for_episode(episode) + self.enable_port(i) + super().reset_component_for_episode(episode) def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index bbeaa62c..3b1f3bf6 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,10 +28,21 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() + def set_original_state(self): + """Sets the original state.""" + super().set_original_state() + vals_to_include = {"dns_table"} + self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] + def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" + print("dns reset") + print("DNSServer original state", self._original_state) self.dns_table.clear() + for key, value in self._original_state["dns_table_orig"].items(): + self.dns_table[key] = value super().reset_component_for_episode(episode) + self.show() def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 754aa22f..56f47195 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -31,7 +31,6 @@ class WebServer(Service): @last_response_status_code.setter def last_response_status_code(self, val: Any): - print(f"val: {val}, type: {type(val)}") self._last_response_status_code = val def describe_state(self) -> Dict: @@ -47,6 +46,7 @@ class WebServer(Service): state["last_response_status_code"] = ( self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) + print(state) return state def __init__(self, **kwargs): @@ -99,7 +99,6 @@ class WebServer(Service): # return true if response is OK self.last_response_status_code = response.status_code - return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: From 37663c941d4ba4345216548e6f47fc0d58abf987 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 28 Nov 2023 00:51:48 +0000 Subject: [PATCH 31/35] #1859 - Added route table reset, still not working --- .../network/hardware/nodes/router.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 1bf2ea2f..667cf2bf 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -354,6 +354,11 @@ class RouteEntry(SimComponent): kwargs[key] = IPv4Address(kwargs[key]) super().__init__(**kwargs) + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} + self._original_values = self.model_dump(include=vals_to_include) + def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -385,6 +390,18 @@ class RouteTable(SimComponent): routes: List[RouteEntry] = [] sys_log: SysLog + def set_original_state(self): + """Sets the original state.""" + """Sets the original state.""" + super().set_original_state() + self._original_state["routes_orig"] = self.routes + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.routes.clear() + self.routes = self._original_state["routes_orig"] + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -660,13 +677,15 @@ class Router(Node): def set_original_state(self): """Sets the original state.""" self.acl.set_original_state() - vals_to_include = {"num_ports", "route_table"} + self.route_table.set_original_state() + vals_to_include = {"num_ports"} self._original_state = self.model_dump(include=vals_to_include) def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() self.acl.reset_component_for_episode(episode) + self.route_table.reset_component_for_episode(episode) for i, nic in self.ethernet_ports.items(): nic.reset_component_for_episode(episode) self.enable_port(i) @@ -765,6 +784,7 @@ class Router(Node): dst_ip_address=dst_ip_address, dst_port=dst_port, ) + if not permitted: at_port = self._get_port_of_nic(from_nic) self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}") From 517f99b04b9e8c2792ad70768a5e3bfa65f9e88a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 28 Nov 2023 09:45:45 +0000 Subject: [PATCH 32/35] #1859 - Added the call to file system reset --- src/primaite/simulator/network/hardware/base.py | 7 +++++++ src/primaite/simulator/network/hardware/nodes/router.py | 1 + 2 files changed, 8 insertions(+) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 09e2b12f..cb159b8b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1009,6 +1009,8 @@ class Node(SimComponent): for software in self.software_manager.software.values(): software.set_original_state() + self.file_system.set_original_state() + for nic in self.nics.values(): nic.set_original_state() @@ -1039,13 +1041,18 @@ class Node(SimComponent): # Reset Session Manager self.session_manager.clear() + # Reset software for software in self.software_manager.software.values(): software.reset_component_for_episode(episode) + # Reset File System + self.file_system.reset_component_for_episode(episode) + # Reset all Nics for nic in self.nics.values(): nic.reset_component_for_episode(episode) + # if episode and self.sys_log: self.sys_log.current_episode = episode self.sys_log.setup_logger() diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 667cf2bf..34b92a07 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -818,6 +818,7 @@ class Router(Node): nic.ip_address = ip_address nic.subnet_mask = subnet_mask self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}") + self.set_original_state() def enable_port(self, port: int): """ From b0399195bbddfce87d6b9c032462c2944a920232 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 27 Nov 2023 22:20:44 +0000 Subject: [PATCH 33/35] Fix software manager usage in uc2 network func --- src/primaite/simulator/network/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 446e5649..b7bd2e95 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -157,7 +157,7 @@ def arcd_uc2_network() -> Network: operating_state=NodeOperatingState.ON, ) client_2.power_on() - web_browser = client_2.software_manager["WebBrowser"] + web_browser = client_2.software_manager.software["WebBrowser"] web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) From 3df3e113d1320b1eed9d7f76a3591a80d7e68c02 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 27 Nov 2023 22:24:30 +0000 Subject: [PATCH 34/35] Change data manipulation test to use the right func --- .../e2e_integration_tests/test_uc2_data_manipulation_scenario.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index fe7bab5f..81bbfc96 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -23,7 +23,6 @@ def test_data_manipulation(uc2_network): # Now we run the DataManipulationBot db_manipulation_bot.run() - db_manipulation_bot.attack() # Now check that the DB client on the web_server cannot query the users table on the database assert not db_client.query("SELECT") From 2de1d02c48805efaed3dfd26d21d56fcc12a7263 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 27 Nov 2023 22:55:00 +0000 Subject: [PATCH 35/35] Fix app install logic --- src/primaite/simulator/system/applications/application.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 4fe7a5e1..898e5917 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -108,9 +108,6 @@ class Application(IOSoftware): def install(self) -> None: """Install Application.""" - if self._can_perform_action(): - return - super().install() if self.operating_state == ApplicationOperatingState.CLOSED: self.sys_log.info(f"Installing Application {self.name}")