diff --git a/.azure/azure-build-deploy-docs-pipeline.yml b/.azure/azure-build-deploy-docs-pipeline.yml index f60840a7..d9926ba7 100644 --- a/.azure/azure-build-deploy-docs-pipeline.yml +++ b/.azure/azure-build-deploy-docs-pipeline.yml @@ -29,17 +29,6 @@ jobs: pip install -e .[dev] displayName: 'Install PrimAITE for docs autosummary' - - script: | - GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl) - python -m pip install $GATE_WHEEL[dev] - displayName: 'Install GATE' - condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - - - script: | - forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]" - displayName: 'Install GATE' - condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | primaite setup displayName: 'Perform PrimAITE Setup' diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index efeba284..9070270a 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -81,17 +81,6 @@ stages: displayName: 'Install PrimAITE' condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | - GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl) - python -m pip install $GATE_WHEEL[dev] - displayName: 'Install GATE' - condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - - - script: | - forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]" - displayName: 'Install GATE' - condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | primaite setup displayName: 'Perform PrimAITE Setup' diff --git a/README.md b/README.md index 9ec8164e..7fc41681 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,6 @@ python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory .\.venv\Scripts\activate pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl -pip install GATE/arcd_gate-0.1.0-py3-none-any.whl primaite setup ``` @@ -75,7 +74,6 @@ cd ~/primaite python3 -m venv .venv source .venv/bin/activate pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl -pip install arcd_gate-0.1.0-py3-none-any.whl primaite setup ``` @@ -120,7 +118,6 @@ source venv/bin/activate ```bash python3 -m pip install -e .[dev] -pip install arcd_gate-0.1.0-py3-none-any.whl ``` #### 6. Perform the PrimAITE setup: diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst index e93c4e54..03f2208b 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -53,7 +53,7 @@ Example network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] - data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;") + data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table. diff --git a/docs/source/simulation_components/system/database_client_server.rst b/docs/source/simulation_components/system/database_client_server.rst index 53687f60..4bef335b 100644 --- a/docs/source/simulation_components/system/database_client_server.rst +++ b/docs/source/simulation_components/system/database_client_server.rst @@ -14,10 +14,10 @@ The ``DatabaseService`` provides a SQL database server simulation by extending t Key capabilities ^^^^^^^^^^^^^^^^ -- Initialises a SQLite database file in the ``Node`` 's ``FileSystem`` upon creation. +- Creates a database file in the ``Node`` 's ``FileSystem`` upon creation. - Handles connecting clients by maintaining a dictionary of connections mapped to session IDs. - Authenticates connections using a configurable password. -- Executes SQL queries against the SQLite database. +- Simulates ``SELECT`` and ``DELETE`` SQL queries. - Returns query results and status codes back to clients. - Leverages the Service base class for install/uninstall, status tracking, etc. @@ -30,10 +30,9 @@ Usage Implementation ^^^^^^^^^^^^^^ -- Uses SQLite for persistent storage. - Creates the database file within the node's file system. - Manages client connections in a dictionary by session ID. -- Processes SQL queries via the SQLite cursor and connection. +- Processes SQL queries. - Returns results and status codes in a standard dictionary format. - Extends Service class for integration with ``SoftwareManager``. diff --git a/pyproject.toml b/pyproject.toml index 51ed84f2..92f78ec0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "stable-baselines3[extra]==2.1.0", "tensorflow==2.12.0", "typer[all]==0.9.0", - "pydantic==2.1.1" + "pydantic==2.1.1", ] [tool.setuptools.dynamic] diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 30fc9ab9..28245d33 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -29,6 +29,15 @@ class _PrimaitePaths: def __init__(self) -> None: self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) + self.user_home_path = self.generate_user_home_path() + self.user_sessions_path = self.generate_user_sessions_path() + self.user_config_path = self.generate_user_config_path() + self.user_notebooks_path = self.generate_user_notebooks_path() + self.app_home_path = self.generate_app_home_path() + self.app_config_dir_path = self.generate_app_config_dir_path() + self.app_config_file_path = self.generate_app_config_file_path() + self.app_log_dir_path = self.generate_app_log_dir_path() + self.app_log_file_path = self.generate_app_log_file_path() def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() @@ -43,55 +52,47 @@ class _PrimaitePaths: for p in self._get_dirs_properties(): getattr(self, p) - @property - def user_home_path(self) -> Path: + def generate_user_home_path(self) -> Path: """The PrimAITE user home path.""" path = Path.home() / "primaite" / __version__ path.mkdir(exist_ok=True, parents=True) return path - @property - def user_sessions_path(self) -> Path: + def generate_user_sessions_path(self) -> Path: """The PrimAITE user sessions path.""" path = self.user_home_path / "sessions" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_config_path(self) -> Path: + def generate_user_config_path(self) -> Path: """The PrimAITE user config path.""" path = self.user_home_path / "config" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_notebooks_path(self) -> Path: + def generate_user_notebooks_path(self) -> Path: """The PrimAITE user notebooks path.""" path = self.user_home_path / "notebooks" path.mkdir(exist_ok=True, parents=True) return path - @property - def app_home_path(self) -> Path: + def generate_app_home_path(self) -> Path: """The PrimAITE app home path.""" path = self._dirs.user_data_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_dir_path(self) -> Path: + def generate_app_config_dir_path(self) -> Path: """The PrimAITE app config directory path.""" path = self._dirs.user_config_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_file_path(self) -> Path: + def generate_app_config_file_path(self) -> Path: """The PrimAITE app config file path.""" return self.app_config_dir_path / "primaite_config.yaml" - @property - def app_log_dir_path(self) -> Path: + def generate_app_log_dir_path(self) -> Path: """The PrimAITE app log directory path.""" if sys.platform == "win32": path = self.app_home_path / "logs" @@ -100,8 +101,7 @@ class _PrimaitePaths: path.mkdir(exist_ok=True, parents=True) return path - @property - def app_log_file_path(self) -> Path: + def generate_app_log_file_path(self) -> Path: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" @@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict: "DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARN": logging.WARN, + "WARNING": logging.WARN, "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } diff --git a/src/primaite/cli.py b/src/primaite/cli.py index a5b3be46..81ab2792 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ - from arcd_gate.cli import setup as gate_setup - from primaite import getLogger from primaite.setup import reset_demo_notebooks, reset_example_configs @@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Rebuilding the example notebooks...") reset_example_configs.run(overwrite_existing=True) - _LOGGER.info("Setting up ARCD GATE...") - gate_setup() - _LOGGER.info("PrimAITE setup complete!") @app.command() def session( config: Optional[str] = None, + agent_load_file: Optional[str] = None, ) -> None: """ Run a PrimAITE session. @@ -131,16 +127,10 @@ def session( :param config: The path to the config file. Optional, if None, the example config will be used. :type config: Optional[str] """ - from threading import Thread - from primaite.config.load import example_config_path from primaite.main import run - from primaite.utils.start_gate_server import start_gate_server - - server_thread = Thread(target=start_gate_server) - server_thread.start() if not config: config = example_config_path() print(config) - run(config_path=config) + run(config_path=config, agent_load_path=agent_load_file) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 8ea1c83c..270760f5 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,10 +2,17 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_episodes: 20 - n_learn_steps: 128 - n_eval_episodes: 20 - n_eval_steps: 128 + n_learn_episodes: 25 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 game_config: @@ -107,7 +114,7 @@ game_config: - ref: defender team: BLUE - type: GATERLAgent + type: ProxyAgent observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/agent/GATE_agents.py b/src/primaite/game/agent/GATE_agents.py deleted file mode 100644 index e4ee16ca..00000000 --- a/src/primaite/game/agent/GATE_agents.py +++ /dev/null @@ -1,31 +0,0 @@ -# flake8: noqa -from typing import Dict, Optional, Tuple - -from gymnasium.core import ActType, ObsType - -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractGATEAgent, ObsType -from primaite.game.agent.observations import ObservationSpace -from primaite.game.agent.rewards import RewardFunction - - -class GATERLAgent(AbstractGATEAgent): - ... - # The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents, - # because when we are supporting MARL, the actions form multiple agents will have to be batched - - # For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply - # with the actions for those agents. - - def __init__( - self, - agent_name: Optional[str], - action_space: Optional[ActionManager], - observation_space: Optional[ObservationSpace], - reward_function: Optional[RewardFunction], - ) -> None: - super().__init__(agent_name, action_space, observation_space, reward_function) - self.most_recent_action: ActType - - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: - return self.most_recent_action diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 33932df2..ff0986a8 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -4,10 +4,11 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union import numpy as np +from gymnasium.core import ActType, ObsType from pydantic import BaseModel from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction if TYPE_CHECKING: @@ -55,7 +56,7 @@ class AbstractAgent(ABC): self, agent_name: Optional[str], action_space: Optional[ActionManager], - observation_space: Optional[ObservationSpace], + observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], agent_settings: Optional[AgentSettings], ) -> None: @@ -72,21 +73,21 @@ class AbstractAgent(ABC): :type reward_function: Optional[RewardFunction] """ self.agent_name: str = agent_name or "unnamed_agent" - self.action_space: Optional[ActionManager] = action_space - self.observation_space: Optional[ObservationSpace] = observation_space + self.action_manager: Optional[ActionManager] = action_space + self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() - def convert_state_to_obs(self, state: Dict) -> ObsType: + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ - return self.observation_space.observe(state) + return self.observation_manager.update(state) - def calculate_reward_from_state(self, state: Dict) -> float: + def update_reward(self, state: Dict) -> float: """ Use the reward function to calculate a reward from the state. @@ -95,10 +96,10 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.calculate(state) + return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -111,7 +112,7 @@ class AbstractAgent(ABC): :return: Action to be taken in the environment. :rtype: Tuple[str, Dict] """ - # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39, + # in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action return ("DO_NOTHING", {}) @@ -119,7 +120,7 @@ class AbstractAgent(ABC): # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - request = self.action_space.form_request(action_identifier=action, action_options=options) + request = self.action_manager.form_request(action_identifier=action, action_options=options) return request @@ -132,7 +133,7 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -142,7 +143,47 @@ class RandomAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - return self.action_space.get_action(self.action_space.space.sample()) + return self.action_manager.get_action(self.action_manager.space.sample()) + + +class ProxyAgent(AbstractAgent): + """Agent that sends observations to an RL model and receives actions from that model.""" + + def __init__( + self, + agent_name: Optional[str], + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + ) -> None: + super().__init__( + agent_name=agent_name, + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + self.most_recent_action: ActType + + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + """ + Return the agent's most recent action, formatted in CAOS format. + + :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. + :type obs: ObsType + :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. + :type reward: float, optional + :return: Action to be taken in CAOS format. + :rtype: Tuple[str, Dict] + """ + return self.action_manager.get_action(self.most_recent_action) + + def store_action(self, action: ActType): + """ + Store the most recent action taken by the agent. + + The environment is responsible for calling this method when it receives an action from the agent policy. + """ + self.most_recent_action = action class DataManipulationAgent(AbstractScriptedAgent): diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a3bafeea..a74771c0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces +from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation): pass -class ObservationSpace: +class ObservationManager: """ Manage the observations of an Agent. @@ -947,15 +948,17 @@ class ObservationSpace: :type observation: AbstractObservation """ self.obs: AbstractObservation = observation + self.current_observation: ObsType - def observe(self, state: Dict) -> Dict: + def update(self, state: Dict) -> Dict: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary :type state: Dict """ - return self.obs.observe(state) + self.current_observation = self.obs.observe(state) + return self.current_observation @property def space(self) -> None: @@ -963,7 +966,7 @@ class ObservationSpace: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace": + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 6c408ff9..da1331b0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,7 +26,7 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, Type, TYPE_CHECKING from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward): class RewardFunction: """Manages the reward function for the agent.""" - __rew_class_identifiers: Dict[str, type[AbstractReward]] = { + __rew_class_identifiers: Dict[str, Type[AbstractReward]] = { "DUMMY": DummyReward, "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, @@ -238,6 +238,7 @@ class RewardFunction: """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] "attribute reward_components keeps track of reward components and the weights assigned to each." + self.current_reward: float def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None: """Add a reward component to the reward function. @@ -249,7 +250,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def calculate(self, state: Dict) -> float: + def update(self, state: Dict) -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. @@ -260,7 +261,8 @@ class RewardFunction: comp = comp_and_weight[0] weight = comp_and_weight[1] total += weight * comp.calculate(state=state) - return total + self.current_reward = total + return self.current_reward @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py new file mode 100644 index 00000000..e0b849c9 --- /dev/null +++ b/src/primaite/game/io.py @@ -0,0 +1,63 @@ +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from primaite import PRIMAITE_PATHS +from primaite.simulator import SIM_OUTPUT + + +class SessionIOSettings(BaseModel): + """Schema for session IO settings.""" + + model_config = ConfigDict(extra="forbid") + + save_final_model: bool = True + """Whether to save the final model right at the end of training.""" + save_checkpoints: bool = False + """Whether to save a checkpoint model every `checkpoint_interval` episodes""" + checkpoint_interval: int = 10 + """How often to save a checkpoint model (if save_checkpoints is True).""" + save_logs: bool = True + """Whether to save logs""" + save_transactions: bool = True + """Whether to save transactions, If true, the session path will have a transactions folder.""" + save_tensorboard_logs: bool = False + """Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder.""" + + +class SessionIO: + """ + Class for managing session IO. + + Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + """ + + def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: + self.settings: SessionIOSettings = settings + self.session_path: Path = self.generate_session_path() + + # set global SIM_OUTPUT path + SIM_OUTPUT.path = self.session_path / "simulation_output" + + # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's + # possible refactor needed + + def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: + """Create a folder for the session and return the path to it.""" + if timestamp is None: + timestamp = datetime.now() + date_str = timestamp.strftime("%Y-%m-%d") + time_str = timestamp.strftime("%H-%M-%S") + session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str + session_path.mkdir(exist_ok=True, parents=True) + return session_path + + def generate_model_save_path(self, agent_name: str) -> Path: + """Return the path where the final model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_final" + + def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: + """Return the path where the checkpoint model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py new file mode 100644 index 00000000..29196112 --- /dev/null +++ b/src/primaite/game/policy/__init__.py @@ -0,0 +1,3 @@ +from primaite.game.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py new file mode 100644 index 00000000..249c3b52 --- /dev/null +++ b/src/primaite/game/policy/policy.py @@ -0,0 +1,84 @@ +"""Base class and common logic for RL policies.""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, Type, TYPE_CHECKING + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession, TrainingOptions + + +class PolicyABC(ABC): + """Base class for reinforcement learning agents.""" + + _registry: Dict[str, Type["PolicyABC"]] = {} + """ + Registry of policy types, keyed by name. + + Automatically populated when PolicyABC subclasses are defined. Used for defining from_config. + """ + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register a policy subclass. + + :param name: Identifier used by from_config to create an instance of the policy. + :type name: str + :raises ValueError: When attempting to create a policy with a duplicate name. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate policy name {identifier}") + cls._registry[identifier] = cls + return + + @abstractmethod + def __init__(self, session: "PrimaiteSession") -> None: + """ + Initialize a reinforcement learning policy. + + :param session: The session context. + :type session: PrimaiteSession + :param agents: The agents to train. + :type agents: List[RLAgent] + """ + self.session: "PrimaiteSession" = session + """Reference to the session.""" + + @abstractmethod + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + pass + + @abstractmethod + def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None: + """Evaluate the agent.""" + pass + + @abstractmethod + def save(self, save_path: Path) -> None: + """Save the agent.""" + pass + + @abstractmethod + def load(self) -> None: + """Load agent from a file.""" + pass + + def close(self) -> None: + """Close the agent.""" + pass + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": + """ + Create an RL policy from a config by calling the relevant subclass's from_config method. + + Subclasses should not call super().from_config(), they should just handle creation form config. + """ + # Assume that basically the contents of training_config are passed into here. + # I should really define a config schema class using pydantic. + + PolicyType = cls._registry[config.rl_framework] + return PolicyType.from_config(config=config, session=session) + + # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py new file mode 100644 index 00000000..a4870054 --- /dev/null +++ b/src/primaite/game/policy/sb3.py @@ -0,0 +1,80 @@ +"""Stable baselines 3 policy.""" +from pathlib import Path +from typing import Literal, Optional, Type, TYPE_CHECKING, Union + +from stable_baselines3 import A2C, PPO +from stable_baselines3.a2c import MlpPolicy as A2C_MLP +from stable_baselines3.common.callbacks import CheckpointCallback +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.ppo import MlpPolicy as PPO_MLP + +from primaite.game.policy.policy import PolicyABC + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession, TrainingOptions + + +class SB3Policy(PolicyABC, identifier="SB3"): + """Single agent RL policy using stable baselines 3.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + """Initialize a stable baselines 3 policy.""" + super().__init__(session=session) + + self._agent_class: Type[Union[PPO, A2C]] + if algorithm == "PPO": + self._agent_class = PPO + policy = PPO_MLP + elif algorithm == "A2C": + self._agent_class = A2C + policy = A2C_MLP + else: + raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy") + self._agent = self._agent_class( + policy=policy, + env=self.session.env, + n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch + seed=seed, + ) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + if self.session.io_manager.settings.save_checkpoints: + checkpoint_callback = CheckpointCallback( + save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval, + save_path=self.session.io_manager.generate_model_save_path("sb3"), + name_prefix="sb3_model", + ) + else: + checkpoint_callback = None + self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback) + + def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate the agent.""" + reward_data = evaluate_policy( + self._agent, + self.session.env, + n_eval_episodes=n_episodes, + deterministic=deterministic, + return_episode_rewards=True, + ) + print(reward_data) + + def save(self, save_path: Path) -> None: + """ + Save the current policy parameters. + + Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please + refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information. + Therefore, this method is only used to save the final model. + """ + self._agent.save(save_path) + + def load(self, model_path: Path) -> None: + """Load agent from a checkpoint.""" + self._agent = self._agent_class.load(model_path, env=self.session.env) + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": + """Create an agent from config file.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 286de498..7856cc9f 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,18 +1,21 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" +from copy import deepcopy +from enum import Enum from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple -from arcd_gate.client.gate_client import ActType, GATEClient -from gymnasium import spaces +import gymnasium from gymnasium.core import ActType, ObsType -from gymnasium.spaces.utils import flatten, flatten_space -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, RandomAgent -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, ProxyAgent, RandomAgent +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.io import SessionIO, SessionIOSettings +from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -34,109 +37,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteGATEClient(GATEClient): - """Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE.""" +class PrimaiteGymEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. - def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000): - """ - Create a new GATE client for PrimAITE. + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. + """ - :param parent_session: The parent session object. - :type parent_session: PrimaiteSession - :param service_port: The port on which the GATE service is running. - :type service_port: int, optional - """ - super().__init__(service_port=service_port) - self.parent_session: "PrimaiteSession" = parent_session + def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + """Initialise the environment.""" + super().__init__() + self.session: "PrimaiteSession" = session + self.agent: ProxyAgent = agents[0] - @property - def rl_framework(self) -> str: - """The reinforcement learning framework to use.""" - return self.parent_session.training_options.rl_framework + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen my the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.session.apply_agent_actions() + self.session.advance_timestep() + state = self.session.get_sim_state() + self.session.update_agents(state) - @property - def rl_algorithm(self) -> str: - """The reinforcement learning algorithm to use.""" - return self.parent_session.training_options.rl_algorithm - - @property - def seed(self) -> Optional[int]: - """The seed to use for the environment's random number generator.""" - return self.parent_session.training_options.seed - - @property - def n_learn_episodes(self) -> int: - """The number of episodes in each learning run.""" - return self.parent_session.training_options.n_learn_episodes - - @property - def n_learn_steps(self) -> int: - """The number of steps in each learning episode.""" - return self.parent_session.training_options.n_learn_steps - - @property - def n_eval_episodes(self) -> int: - """The number of episodes in each evaluation run.""" - return self.parent_session.training_options.n_eval_episodes - - @property - def n_eval_steps(self) -> int: - """The number of steps in each evaluation episode.""" - return self.parent_session.training_options.n_eval_steps - - @property - def action_space(self) -> spaces.Space: - """The gym action space of the agent.""" - return self.parent_session.rl_agent.action_space.space - - @property - def observation_space(self) -> spaces.Space: - """The gymnasium observation space of the agent.""" - return flatten_space(self.parent_session.rl_agent.observation_space.space) - - def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]: - """Take a step in the environment. - - This method is called by GATE to advance the simulation by one timestep. - - :param action: The agent's action. - :type action: ActType - :return: The observation, reward, terminal flag, truncated flag, and info dictionary. - :rtype: Tuple[ObsType, float, bool, bool, Dict] - """ - self.parent_session.rl_agent.most_recent_action = action - self.parent_session.step() - state = self.parent_session.simulation.describe_state() - obs = self.parent_session.rl_agent.observation_space.observe(state) - obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - rew = self.parent_session.rl_agent.reward_function.calculate(state) - term = False - trunc = False + next_obs = self._get_obs() + reward = self.agent.reward_function.current_reward + terminated = False + truncated = self.session.calculate_truncated() info = {} - return obs, rew, term, trunc, info - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict]: - """Reset the environment. + return next_obs, reward, terminated, truncated, info - This method is called when the environment is initialized and at the end of each episode. + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: + """Reset the environment.""" + self.session.reset() + state = self.session.get_sim_state() + self.session.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info - :param seed: The seed to use for the environment's random number generator. - :type seed: int, optional - :param options: Additional options for the reset. None are used by PrimAITE but this is included for - compatibility with GATE. - :type options: dict[str, Any], optional - :return: The initial observation and an empty info dictionary. - :rtype: Tuple[ObsType, Dict] - """ - self.parent_session.reset() - state = self.parent_session.simulation.describe_state() - obs = self.parent_session.rl_agent.observation_space.observe(state) - obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - return obs, {} + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.space - def close(self): - """Close the session, this will stop the gate client and close the simulation.""" - self.parent_session.close() + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) class PrimaiteSessionOptions(BaseModel): @@ -146,6 +102,8 @@ class PrimaiteSessionOptions(BaseModel): Currently this is used to restrict which ports and protocols exist in the world of the simulation. """ + model_config = ConfigDict(extra="forbid") + ports: List[str] protocols: List[str] @@ -153,48 +111,105 @@ class PrimaiteSessionOptions(BaseModel): class TrainingOptions(BaseModel): """Options for training the RL agent.""" - rl_framework: str - rl_algorithm: str - seed: Optional[int] + model_config = ConfigDict(extra="forbid") + + rl_framework: Literal["SB3", "RLLIB"] + rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int - n_learn_steps: int - n_eval_episodes: int - n_eval_steps: int + n_eval_episodes: Optional[int] = None + max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None + deterministic_eval: bool + seed: Optional[int] + n_agents: int + agent_references: List[str] + + +class SessionMode(Enum): + """Helper to keep track of the current session mode.""" + + TRAIN = "train" + EVAL = "eval" + MANUAL = "manual" class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE.""" + """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" def __init__(self): + """Initialise a PrimaiteSession object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + + self._simulation_initial_state = deepcopy(self.simulation) + """The Simulation original state (deepcopy of the original Simulation).""" + self.agents: List[AbstractAgent] = [] """List of agents.""" - self.rl_agent: AbstractAgent - """The agent from the list which communicates with GATE to perform reinforcement learning.""" + + self.rl_agents: List[ProxyAgent] = [] + """Subset of agent list including only the reinforcement learning agents.""" + self.step_counter: int = 0 """Current timestep within the episode.""" + self.episode_counter: int = 0 """Current episode number.""" + self.options: PrimaiteSessionOptions """Special options that apply for the entire game.""" + self.training_options: TrainingOptions """Options specific to agent training.""" + self.policy: PolicyABC + """The reinforcement learning policy.""" + self.ref_map_nodes: Dict[str, Node] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" + self.ref_map_services: Dict[str, Service] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" + self.ref_map_applications: Dict[str, Application] = {} """Mapping from human-readable application reference to application object. Used for parsing config files.""" + self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" - self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self) - """Reference to a GATE Client object, which will send data to GATE service for training RL agent.""" + + self.env: PrimaiteGymEnv + """The environment that the agent can consume. Could be PrimaiteEnv.""" + + self.mode: SessionMode = SessionMode.MANUAL + """Current session mode.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" def start_session(self) -> None: - """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - self.gate_client.start() + """Commence the training session.""" + self.mode = SessionMode.TRAIN + n_learn_episodes = self.training_options.n_learn_episodes + n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + + deterministic_eval = self.training_options.deterministic_eval + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) + self.save_models() + + self.mode = SessionMode.EVAL + if n_eval_episodes > 0: + self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) + + self.mode = SessionMode.MANUAL + + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) def step(self): """ @@ -208,57 +223,76 @@ class PrimaiteSession: 4. Each agent chooses an action based on the observation. 5. Each agent converts the action to a request. 6. The simulation applies the requests. + + Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent + interacts with should implement a step method that calls methods used by this method. For example, if using a + single-agent gym, make sure to update the ProxyAgent's action with the action before calling + ``self.apply_agent_actions()``. """ _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") - # currently designed with assumption that all agents act once per step in order + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state + self.update_agents(sim_state) + + # Apply all actions to simulation as requests + self.apply_agent_actions() + + # Advance timestep + self.advance_timestep() + + def get_sim_state(self) -> Dict: + """Get the current state of the simulation.""" + return self.simulation.describe_state() + + def update_agents(self, state: Dict) -> None: + """Update agents' observations and rewards based on the current state.""" for agent in self.agents: - # 3. primaite session asks simulation to provide initial state - # 4. primate session gives state to all agents - # 5. primaite session asks agents to produce an action based on most recent state - _LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}") - sim_state = self.simulation.describe_state() + agent.update_observation(state) + agent.update_reward(state) - # 6. each agent takes most recent state and converts it to CAOS observation - agent_obs = agent.convert_state_to_obs(sim_state) + def apply_agent_actions(self) -> None: + """Apply all actions to simulation as requests.""" + for agent in self.agents: + obs = agent.observation_manager.current_observation + rew = agent.reward_function.current_reward + action_choice, options = agent.get_action(obs, rew) + request = agent.format_request(action_choice, options) + self.simulation.apply_request(request) - # 7. meanwhile each agent also takes state and calculates reward - agent_reward = agent.calculate_reward_from_state(sim_state) - - # 8. each agent takes observation and applies decision rule to observation to create CAOS - # action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action - # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of - # code should live inside of the GATE agent subclass) - # gets action in CAOS format - _LOGGER.debug("Getting agent action") - agent_action, action_options = agent.get_action(agent_obs, agent_reward) - # 9. CAOS action is converted into request (extra information might be needed to enrich - # the request, this is what the execution definition is there for) - _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements - agent_request = agent.format_request(agent_action, action_options) - - # 10. primaite session receives the action from the agents and asks the simulation to apply each - _LOGGER.debug(f"Sending request to simulation: {agent_request}") - self.simulation.apply_request(agent_request) - - _LOGGER.debug(f"Initiating simulation step {self.step_counter}") - self.simulation.apply_timestep(self.step_counter) + def advance_timestep(self) -> None: + """Advance timestep.""" self.step_counter += 1 + _LOGGER.debug(f"Advancing timestep to {self.step_counter} ") + self.simulation.apply_timestep(self.step_counter) + + def calculate_truncated(self) -> bool: + """Calculate whether the episode is truncated.""" + current_step = self.step_counter + max_steps = self.training_options.max_steps_per_episode + if current_step >= max_steps: + return True + return False def reset(self) -> None: """Reset the session, this will reset the simulation.""" - return NotImplemented + self.episode_counter += 1 + self.step_counter = 0 + _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") + self.simulation = deepcopy(self._simulation_initial_state) def close(self) -> None: - """Close the session, this will stop the gate client and close the simulation.""" + """Close the session, this will stop the env and close the simulation.""" return NotImplemented @classmethod - def from_config(cls, cfg: dict) -> "PrimaiteSession": + def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary. The config dictionary should have the following top-level keys: - 1. training_config: options for training the RL agent. Used by GATE. + 1. training_config: options for training the RL agent. 2. game_config: options for the game itself. Used by PrimaiteSession. 3. simulation: defines the network topology and the initial state of the simulation. @@ -276,6 +310,11 @@ class PrimaiteSession: protocols=cfg["game_config"]["protocols"], ) sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) + sim = sess.simulation net = sim.network @@ -425,7 +464,7 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationSpace.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, sess) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] @@ -479,16 +518,15 @@ class PrimaiteSession: agent_settings=agent_settings, ) sess.agents.append(new_agent) - elif agent_type == "GATERLAgent": - new_agent = RandomAgent( + elif agent_type == "ProxyAgent": + new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=rew_function, - agent_settings=agent_settings, ) sess.agents.append(new_agent) - sess.rl_agent = new_agent + sess.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], @@ -501,4 +539,14 @@ class PrimaiteSession: else: print("agent type not found") + # CREATE ENVIRONMENT + sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) + + # CREATE POLICY + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) + + sess._simulation_initial_state = deepcopy(sess.simulation) # noqa + return sess diff --git a/src/primaite/main.py b/src/primaite/main.py index 831419d4..1699fe51 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__) def run( config_path: Optional[Union[str, Path]] = "", + agent_load_path: Optional[Union[str, Path]] = None, ) -> None: """ Run the PrimAITE Session. @@ -31,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteSession.from_config(cfg=cfg) + sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 8c55542f..19c86e28 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -1,14 +1,26 @@ +"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory.""" from datetime import datetime +from pathlib import Path from primaite import _PRIMAITE_ROOT -SIM_OUTPUT = None -"A path at the repo root dir to use temporarily for sim output testing while in dev." -# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path +__all__ = ["SIM_OUTPUT"] -if not SIM_OUTPUT: - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path - SIM_OUTPUT.mkdir(exist_ok=True, parents=True) + +class __SimOutput: + def __init__(self): + self._path: Path = ( + _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ) + + @property + def path(self) -> Path: + return self._path + + @path.setter + def path(self, new_path: Path) -> None: + self._path = new_path + self._path.mkdir(exist_ok=True, parents=True) + + +SIM_OUTPUT = __SimOutput() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 537cebb2..29d3a05c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -957,7 +957,7 @@ class Node(SimComponent): if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 25d1bd21..c0f9a07e 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network: network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] - db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;") + db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") # Client 2 client_2 = Computer( diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 9d85221e..a5c213cd 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -2,13 +2,14 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional from uuid import uuid4 -from prettytable import PrettyTable - +from primaite import getLogger from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.software_manager import SoftwareManager +_LOGGER = getLogger(__name__) + class DatabaseClient(Application): """ @@ -130,7 +131,7 @@ class DatabaseClient(Application): def execute(self) -> None: """Run the DatabaseClient.""" - # super().execute() + super().execute() if self.operating_state == ApplicationOperatingState.RUNNING: self.connect() @@ -148,21 +149,6 @@ class DatabaseClient(Application): self._query_success_tracker[query_id] = False return self._query(sql=sql, query_id=query_id) - def _print_data(self, data: Dict): - """ - Display the contents of the Folder in tabular format. - - :param markdown: Whether to display the table in Markdown format or not. Default is `False`. - """ - if data: - table = PrettyTable(list(data.values())[0]) - - table.align = "l" - table.title = f"{self.sys_log.hostname} Database Client" - for row in data.values(): - table.add_row(row.values()) - print(table) - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receive a payload from the Software Manager. @@ -179,5 +165,5 @@ class DatabaseClient(Application): status_code = payload.get("status_code") self._query_success_tracker[query_id] = status_code == 200 if self._query_success_tracker[query_id]: - self._print_data(payload["data"]) + _LOGGER.debug(f"Received payload {payload}") return True diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 964e1ce4..ea9c3ac3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -30,7 +30,7 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.execute() + self.run() def describe_state(self) -> Dict: """ @@ -135,6 +135,3 @@ class WebBrowser(Application): self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}") self.latest_response = payload return True - - def execute(self): - pass diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index 2e5ed008..c2faeb10 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -75,7 +75,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 8b8fe599..21a121c1 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -100,8 +100,16 @@ class SoftwareManager: self.node.uninstall_application(software) elif isinstance(software, Service): self.node.uninstall_service(software) + for key, value in self.port_protocol_mapping.items(): + if value.name == software_name: + self.port_protocol_mapping.pop(key) + break + for key, value in self._software_class_to_name_map.items(): + if value == software_name: + self._software_class_to_name_map.pop(key) + break del software - self.sys_log.info(f"Deleted {software_name}") + self.sys_log.info(f"Uninstalled {software_name}") return self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 791e0be8..7ac6df85 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -41,7 +41,6 @@ class SysLog: JSON-like messages. """ log_path = self._get_log_path() - file_handler = logging.FileHandler(filename=log_path) file_handler.setLevel(logging.DEBUG) @@ -81,7 +80,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index b04174bf..d7277e1e 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -1,10 +1,6 @@ -import sqlite3 from datetime import datetime from ipaddress import IPv4Address -from sqlite3 import OperationalError -from typing import Any, Dict, List, Optional, Union - -from prettytable import MARKDOWN, PrettyTable +from typing import Any, Dict, List, Literal, Optional, Union from primaite.simulator.file_system.file_system import File from primaite.simulator.network.transmission.network_layer import IPProtocol @@ -19,7 +15,7 @@ class DatabaseService(Service): """ A class for simulating a generic SQL Server service. - This class inherits from the `Service` class and provides methods to manage and query a SQLite database. + This class inherits from the `Service` class and provides methods to simulate a SQL database. """ password: Optional[str] = None @@ -41,38 +37,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._db_file: File self._create_db_file() - self._connect() - - def _connect(self): - self._conn = sqlite3.connect(self._db_file.sim_path) - self._cursor = self._conn.cursor() - - def tables(self) -> List[str]: - """ - Get a list of table names present in the database. - - :return: List of table names. - """ - sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';" - results = self._process_sql(sql, None) - if isinstance(results["data"], dict): - return list(results["data"].keys()) - return [] - - def show(self, markdown: bool = False): - """ - Prints a list of table names in the database using PrettyTable. - - :param markdown: Whether to output the table in Markdown format. - """ - table = PrettyTable(["Table"]) - if markdown: - table.set_style(MARKDOWN) - table.align = "l" - table.title = f"{self.file_system.sys_log.hostname} Database" - for row in self.tables(): - table.add_row([row]) - print(table) def configure_backup(self, backup_server: IPv4Address): """ @@ -89,8 +53,6 @@ class DatabaseService(Service): self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.") return False - self._conn.close() - software_manager: SoftwareManager = self.software_manager ftp_client_service: FTPClient = software_manager.software["FTPClient"] @@ -98,12 +60,10 @@ class DatabaseService(Service): response = ftp_client_service.send_file( dest_ip_address=self.backup_server, src_file_name=self._db_file.name, - src_folder_name=self._db_file.folder.name, + src_folder_name=self.folder.name, dest_folder_name=str(self.uuid), dest_file_name="database.db", - real_file_path=self._db_file.sim_path, ) - self._connect() if response: return True @@ -125,25 +85,29 @@ class DatabaseService(Service): dest_ip_address=self.backup_server, ) - if response: - self._conn.close() - # replace db file - self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db") - self.file_system.copy_file( - src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name - ) - self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db") - self._connect() + if not response: + self.sys_log.error("Unable to restore database backup.") + return False - return self._db_file is not None + # replace db file + self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db") + self.file_system.copy_file( + src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name + ) + self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db") - self.sys_log.error("Unable to restore database backup.") - return False + if self._db_file is None: + self.sys_log.error("Copying database backup failed.") + return False + + self.set_health_state(SoftwareHealthState.GOOD) + + return True def _create_db_file(self): """Creates the Simulation File and sqlite file in the file system.""" - self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True) - self.folder = self._db_file.folder + self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db") + self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id) def _process_connect( self, session_id: str, password: Optional[str] = None @@ -163,31 +127,32 @@ class DatabaseService(Service): status_code = 404 # service not found return {"status_code": status_code, "type": "connect_response", "response": status_code == 200} - def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]: + def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]: """ Executes the given SQL query and returns the result. + Possible queries: + - SELECT : returns the data + - DELETE : deletes the data + :param query: The SQL query to be executed. :return: Dictionary containing status code and data fetched. """ self.sys_log.info(f"{self.name}: Running {query}") - try: - self._cursor.execute(query) - self._conn.commit() - except OperationalError: - # Handle the case where the table does not exist. - self.sys_log.error(f"{self.name}: Error, query failed") - return {"status_code": 404, "data": {}} - data = [] - description = self._cursor.description - if description: - headers = [] - for header in description: - headers.append(header[0]) - data = self._cursor.fetchall() - if data and headers: - data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data} - return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id} + if query == "SELECT": + if self.health_state_actual == SoftwareHealthState.GOOD: + return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id} + else: + return {"status_code": 404, "data": False} + elif query == "DELETE": + if self.health_state_actual == SoftwareHealthState.GOOD: + self.health_state_actual = SoftwareHealthState.COMPROMISED + return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id} + else: + return {"status_code": 404, "data": False} + else: + # Invalid query + return {"status_code": 500, "data": False} def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5957e4cb..cb1a4738 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -106,7 +106,7 @@ class WebServer(Service): # get data from DatabaseServer db_client: DatabaseClient = self.software_manager.software["DatabaseClient"] # get all users - if db_client.query("SELECT * FROM user;"): + if db_client.query("SELECT"): # query succeeded response.status_code = HttpStatusCode.OK diff --git a/src/primaite/utils/start_gate_server.py b/src/primaite/utils/start_gate_server.py deleted file mode 100644 index d91952f2..00000000 --- a/src/primaite/utils/start_gate_server.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Utility script to start the gate server for running PrimAITE in attached mode.""" -from arcd_gate.server.gate_service import GATEService - - -def start_gate_server(): - """Start the gate server.""" - service = GATEService() - service.start() - - -if __name__ == "__main__": - start_gate_server() diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml new file mode 100644 index 00000000..80567aea --- /dev/null +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -0,0 +1,725 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + se3ed: 333 # Purposeful typo to check that error is raised with bad configuration. + n_learn_steps: 2560 + n_eval_episodes: 5 + + + +game_config: + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + + agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # FileSystem: # PrimAITE v2 stuff -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -class TempPrimaiteSession: # PrimaiteSession): +class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. Uses context manager for deletion of files upon exit. """ - # def __init__( - # self, - # training_config_path: Union[str, Path], - # lay_down_config_path: Union[str, Path], - # ): - # super().__init__(training_config_path, lay_down_config_path) - # self.setup() + @classmethod + def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession": + """Create a temporary PrimaiteSession object from a config file.""" + config_path = Path(config_path) + with open(config_path, "r") as f: + config = yaml.safe_load(f) - # @property - # def env(self) -> Primaite: - # """Direct access to the env for ease of testing.""" - # return self._agent_session._env # noqa + return super().from_config(cfg=config) - # def __enter__(self): - # return self + def __enter__(self): + return self - # def __exit__(self, type, value, tb): - # shutil.rmtree(self.session_path) - # _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") + def __exit__(self, type, value, tb): + pass -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 @pytest.fixture -def temp_primaite_session(request): - """ - Provides a temporary PrimaiteSession instance. - - It's temporary as it uses a temporary directory as the session path. - - To use this fixture you need to: - - - parametrize your test function with: - - - "temp_primaite_session" - - [[path to training config, path to lay down config]] - - Include the temp_primaite_session fixture as a param in your test - function. - - use the temp_primaite_session as a context manager assigning is the - name 'session'. - - .. code:: python - - from primaite.config.lay_down_config import dos_very_basic_config_path - from primaite.config.training_config import main_training_config_path - @pytest.mark.parametrize( - "temp_primaite_session", - [ - [main_training_config_path(), dos_very_basic_config_path()] - ], - indirect=True - ) - def test_primaite_session(temp_primaite_session): - with temp_primaite_session as session: - # Learning outputs are saved in session.learning_path - session.learn() - - # Evaluation outputs are saved in session.evaluation_path - session.evaluate() - - # To ensure that all files are written, you must call .close() - session.close() - - # If you need to inspect any session outputs, it must be done - # inside the context manager - - # Now that we've exited the context manager, the - # session.session_path directory and its contents are deleted - """ - training_config_path = request.param[0] - lay_down_config_path = request.param[1] - with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck: - mck.session_timestamp = datetime.now() - - return TempPrimaiteSession(training_config_path, lay_down_config_path) - - -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -@pytest.fixture -def temp_session_path() -> Path: - """ - Get a temp directory session path the test session will output to. - - :return: The session directory path. - """ - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - - return session_path +def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: + """Create a temporary PrimaiteSession object.""" + monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py new file mode 100644 index 00000000..6839a190 --- /dev/null +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -0,0 +1,83 @@ +import pydantic +import pytest + +from tests import TEST_ASSETS_ROOT +from tests.conftest import TempPrimaiteSession + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" +TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml" +EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml" +MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml" + + +class TestPrimaiteSession: + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_creating_session(self, temp_primaite_session): + """Check that creating a session from config works.""" + with temp_primaite_session as session: + if not isinstance(session, TempPrimaiteSession): + raise AssertionError + + assert session is not None + assert session.simulation + assert len(session.agents) == 3 + assert len(session.rl_agents) == 1 + + assert session.policy + assert session.env + + assert session.simulation.network + assert len(session.simulation.network.nodes) == 10 + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_start_session(self, temp_primaite_session): + """Make sure you can go all the way through the session without errors.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + + session_path = session.io_manager.session_path + assert session_path.exists() + print(list(session_path.glob("*"))) + checkpoint_dir = session_path / "checkpoints" / "sb3_final" + assert checkpoint_dir.exists() + checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip" + checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip" + checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip" + assert checkpoint_1.exists() + assert checkpoint_2.exists() + assert not checkpoint_3.exists() + + @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) + def test_training_only_session(self, temp_primaite_session): + """Check that you can run a training-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? + + @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) + def test_eval_only_session(self, temp_primaite_session): + """Check that you can load a model and run an eval-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was loaded and that the eval-only session ran + + def test_error_thrown_on_bad_configuration(self): + with pytest.raises(pydantic.ValidationError): + session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH) + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_session_sim_reset(self, temp_primaite_session): + with temp_primaite_session as session: + session: TempPrimaiteSession + client_1 = session.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.uninstall("DataManipulationBot") + + assert "DataManipulationBot" not in client_1.software_manager.software + + session.reset() + client_1 = session.simulation.network.get_node_by_hostname("client_1") + + assert "DataManipulationBot" in client_1.software_manager.software diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 13f4d1f3..81bbfc96 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -19,16 +19,16 @@ def test_data_manipulation(uc2_network): db_service.backup_database() # First check that the DB client on the web_server can successfully query the users table on the database - assert db_client.query("SELECT * FROM user;") + assert db_client.query("SELECT") # Now we run the DataManipulationBot db_manipulation_bot.run() # Now check that the DB client on the web_server cannot query the users table on the database - assert not db_client.query("SELECT * FROM user;") + assert not db_client.query("SELECT") # Now restore the database db_service.restore_backup() # Now check that the DB client on the web_server can successfully query the users table on the database - assert db_client.query("SELECT * FROM user;") + assert db_client.query("SELECT") diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 92056981..027fae4a 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -57,7 +57,7 @@ def test_database_client_query(uc2_network): db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client.connect() - assert db_client.query("SELECT * FROM user;") + assert db_client.query("SELECT") def test_create_database_backup(uc2_network): diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index e36cff2b..f4546cbf 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -10,7 +10,7 @@ def test_web_page_home_page(uc2_network): """Test to see if the browser is able to open the main page of the web server.""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.execute() + web_client.run() assert web_client.operating_state == ApplicationOperatingState.RUNNING assert web_client.get_webpage("http://arcd.com/") is True @@ -24,7 +24,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network): """Test to see if the client can handle requests with domain names""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.execute() + web_client.run() assert web_client.operating_state == ApplicationOperatingState.RUNNING assert web_client.get_webpage("http://arcd.com/users/") is True @@ -38,7 +38,7 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network): """Test to see if the client can handle requests that use ip_address.""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.execute() + web_client.run() web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index 16c4a274..06fe5893 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -9,7 +9,7 @@ from primaite import getLogger _LOGGER = getLogger(__name__) -def get_temp_session_path(session_timestamp: datetime) -> Path: +def temp_user_sessions_path() -> Path: """ Get a temp directory session path the test session will output to. diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index 04e23e84..8a78beae 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -27,7 +27,7 @@ def test_create_dm_bot(dm_client): assert data_manipulation_bot.name == "DataManipulationBot" assert data_manipulation_bot.port == Port.POSTGRES_SERVER assert data_manipulation_bot.protocol == IPProtocol.TCP - assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;" + assert data_manipulation_bot.payload == "DELETE" def test_dm_bot_logon(dm_bot):