diff --git a/.azure/azure-build-deploy-docs-pipeline.yml b/.azure/azure-build-deploy-docs-pipeline.yml index f60840a7..d9926ba7 100644 --- a/.azure/azure-build-deploy-docs-pipeline.yml +++ b/.azure/azure-build-deploy-docs-pipeline.yml @@ -29,17 +29,6 @@ jobs: pip install -e .[dev] displayName: 'Install PrimAITE for docs autosummary' - - script: | - GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl) - python -m pip install $GATE_WHEEL[dev] - displayName: 'Install GATE' - condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - - - script: | - forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]" - displayName: 'Install GATE' - condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | primaite setup displayName: 'Perform PrimAITE Setup' diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index efeba284..9070270a 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -81,17 +81,6 @@ stages: displayName: 'Install PrimAITE' condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | - GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl) - python -m pip install $GATE_WHEEL[dev] - displayName: 'Install GATE' - condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - - - script: | - forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]" - displayName: 'Install GATE' - condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | primaite setup displayName: 'Perform PrimAITE Setup' diff --git a/README.md b/README.md index 9ec8164e..7fc41681 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,6 @@ python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory .\.venv\Scripts\activate pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl -pip install GATE/arcd_gate-0.1.0-py3-none-any.whl primaite setup ``` @@ -75,7 
+74,6 @@ cd ~/primaite python3 -m venv .venv source .venv/bin/activate pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl -pip install arcd_gate-0.1.0-py3-none-any.whl primaite setup ``` @@ -120,7 +118,6 @@ source venv/bin/activate ```bash python3 -m pip install -e .[dev] -pip install arcd_gate-0.1.0-py3-none-any.whl ``` #### 6. Perform the PrimAITE setup: diff --git a/pyproject.toml b/pyproject.toml index 51ed84f2..92f78ec0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "stable-baselines3[extra]==2.1.0", "tensorflow==2.12.0", "typer[all]==0.9.0", - "pydantic==2.1.1" + "pydantic==2.1.1", ] [tool.setuptools.dynamic] diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 30fc9ab9..28245d33 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -29,6 +29,15 @@ class _PrimaitePaths: def __init__(self) -> None: self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) + self.user_home_path = self.generate_user_home_path() + self.user_sessions_path = self.generate_user_sessions_path() + self.user_config_path = self.generate_user_config_path() + self.user_notebooks_path = self.generate_user_notebooks_path() + self.app_home_path = self.generate_app_home_path() + self.app_config_dir_path = self.generate_app_config_dir_path() + self.app_config_file_path = self.generate_app_config_file_path() + self.app_log_dir_path = self.generate_app_log_dir_path() + self.app_log_file_path = self.generate_app_log_file_path() def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() @@ -43,55 +52,47 @@ class _PrimaitePaths: for p in self._get_dirs_properties(): getattr(self, p) - @property - def user_home_path(self) -> Path: + def generate_user_home_path(self) -> Path: """The PrimAITE user home path.""" path = Path.home() / "primaite" / __version__ path.mkdir(exist_ok=True, 
parents=True) return path - @property - def user_sessions_path(self) -> Path: + def generate_user_sessions_path(self) -> Path: """The PrimAITE user sessions path.""" path = self.user_home_path / "sessions" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_config_path(self) -> Path: + def generate_user_config_path(self) -> Path: """The PrimAITE user config path.""" path = self.user_home_path / "config" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_notebooks_path(self) -> Path: + def generate_user_notebooks_path(self) -> Path: """The PrimAITE user notebooks path.""" path = self.user_home_path / "notebooks" path.mkdir(exist_ok=True, parents=True) return path - @property - def app_home_path(self) -> Path: + def generate_app_home_path(self) -> Path: """The PrimAITE app home path.""" path = self._dirs.user_data_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_dir_path(self) -> Path: + def generate_app_config_dir_path(self) -> Path: """The PrimAITE app config directory path.""" path = self._dirs.user_config_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_file_path(self) -> Path: + def generate_app_config_file_path(self) -> Path: """The PrimAITE app config file path.""" return self.app_config_dir_path / "primaite_config.yaml" - @property - def app_log_dir_path(self) -> Path: + def generate_app_log_dir_path(self) -> Path: """The PrimAITE app log directory path.""" if sys.platform == "win32": path = self.app_home_path / "logs" @@ -100,8 +101,7 @@ class _PrimaitePaths: path.mkdir(exist_ok=True, parents=True) return path - @property - def app_log_file_path(self) -> Path: + def generate_app_log_file_path(self) -> Path: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" @@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict: "DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARN": logging.WARN, + "WARNING": logging.WARN, 
"ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } diff --git a/src/primaite/cli.py b/src/primaite/cli.py index a5b3be46..81ab2792 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ - from arcd_gate.cli import setup as gate_setup - from primaite import getLogger from primaite.setup import reset_demo_notebooks, reset_example_configs @@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Rebuilding the example notebooks...") reset_example_configs.run(overwrite_existing=True) - _LOGGER.info("Setting up ARCD GATE...") - gate_setup() - _LOGGER.info("PrimAITE setup complete!") @app.command() def session( config: Optional[str] = None, + agent_load_file: Optional[str] = None, ) -> None: """ Run a PrimAITE session. @@ -131,16 +127,10 @@ def session( :param config: The path to the config file. Optional, if None, the example config will be used. 
:type config: Optional[str] """ - from threading import Thread - from primaite.config.load import example_config_path from primaite.main import run - from primaite.utils.start_gate_server import start_gate_server - - server_thread = Thread(target=start_gate_server) - server_thread.start() if not config: config = example_config_path() print(config) - run(config_path=config) + run(config_path=config, agent_load_path=agent_load_file) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index ddf9d923..54007ef0 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,10 +2,17 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_episodes: 20 - n_learn_steps: 128 - n_eval_episodes: 20 - n_eval_steps: 128 + n_learn_episodes: 25 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 game_config: @@ -108,7 +115,7 @@ game_config: - ref: defender team: BLUE - type: GATERLAgent + type: ProxyAgent observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/agent/GATE_agents.py b/src/primaite/game/agent/GATE_agents.py deleted file mode 100644 index e50d7831..00000000 --- a/src/primaite/game/agent/GATE_agents.py +++ /dev/null @@ -1,31 +0,0 @@ -# flake8: noqa -from typing import Dict, Optional, Tuple - -from gymnasium.core import ActType, ObsType - -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractGATEAgent, ObsType -from primaite.game.agent.observations import ObservationSpace -from primaite.game.agent.rewards import RewardFunction - - -class GATERLAgent(AbstractGATEAgent): - ... 
- # The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents, - # because when we are supporting MARL, the actions form multiple agents will have to be batched - - # For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply - # with the actions for those agents. - - def __init__( - self, - agent_name: str | None, - action_space: ActionManager | None, - observation_space: ObservationSpace | None, - reward_function: RewardFunction | None, - ) -> None: - super().__init__(agent_name, action_space, observation_space, reward_function) - self.most_recent_action: ActType - - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: - return self.most_recent_action diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 89f27f3f..75d209ce 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,15 +1,13 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Dict, List, Optional, Tuple -import numpy as np +from gymnasium.core import ActType, ObsType from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -ObsType: TypeAlias = Union[Dict, np.ndarray] - class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -18,7 +16,7 @@ class AbstractAgent(ABC): self, agent_name: Optional[str], action_space: Optional[ActionManager], - observation_space: Optional[ObservationSpace], + observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], ) -> None: """ @@ -34,24 +32,24 @@ class AbstractAgent(ABC): :type reward_function: Optional[RewardFunction] """ self.agent_name: 
str = agent_name or "unnamed_agent" - self.action_space: Optional[ActionManager] = action_space - self.observation_space: Optional[ObservationSpace] = observation_space + self.action_manager: Optional[ActionManager] = action_space + self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info # by for example specifying target ip addresses, or converting a node ID into a uuid self.execution_definition = None - def convert_state_to_obs(self, state: Dict) -> ObsType: + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ - return self.observation_space.observe(state) + return self.observation_manager.update(state) - def calculate_reward_from_state(self, state: Dict) -> float: + def update_reward(self, state: Dict) -> float: """ Use the reward function to calculate a reward from the state. @@ -60,10 +58,10 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.calculate(state) + return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -76,7 +74,7 @@ class AbstractAgent(ABC): :return: Action to be taken in the environment. 
:rtype: Tuple[str, Dict] """ - # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39, + # in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action return ("DO_NOTHING", {}) @@ -84,7 +82,7 @@ class AbstractAgent(ABC): # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - request = self.action_space.form_request(action_identifier=action, action_options=options) + request = self.action_manager.form_request(action_identifier=action, action_options=options) return request @@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -107,10 +105,44 @@ class RandomAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - return self.action_space.get_action(self.action_space.space.sample()) + return self.action_manager.get_action(self.action_manager.space.sample()) -class AbstractGATEAgent(AbstractAgent): - """Base class for actors controlled via external messages, such as RL policies.""" +class ProxyAgent(AbstractAgent): + """Agent that sends observations to an RL model and receives actions from that model.""" - ... 
+ def __init__( + self, + agent_name: Optional[str], + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + ) -> None: + super().__init__( + agent_name=agent_name, + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + self.most_recent_action: ActType + + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + """ + Return the agent's most recent action, formatted in CAOS format. + + :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. + :type obs: ObsType + :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to 0.0. + :type reward: float, optional + :return: Action to be taken in CAOS format. + :rtype: Tuple[str, Dict] + """ + return self.action_manager.get_action(self.most_recent_action) + + def store_action(self, action: ActType): + """ + Store the most recent action taken by the agent. + + The environment is responsible for calling this method when it receives an action from the agent policy. + """ + self.most_recent_action = action diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a3bafeea..a74771c0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces +from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation): pass -class ObservationSpace: +class ObservationManager: """ Manage the observations of an Agent.
@@ -947,15 +948,17 @@ class ObservationSpace: :type observation: AbstractObservation """ self.obs: AbstractObservation = observation + self.current_observation: ObsType - def observe(self, state: Dict) -> Dict: + def update(self, state: Dict) -> Dict: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary :type state: Dict """ - return self.obs.observe(state) + self.current_observation = self.obs.observe(state) + return self.current_observation @property def space(self) -> None: @@ -963,7 +966,7 @@ class ObservationSpace: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace": + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 6c408ff9..da1331b0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,7 +26,7 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, Type, TYPE_CHECKING from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward): class RewardFunction: """Manages the reward function for the agent.""" - __rew_class_identifiers: Dict[str, type[AbstractReward]] = { + __rew_class_identifiers: Dict[str, Type[AbstractReward]] = { "DUMMY": DummyReward, "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, @@ -238,6 +238,7 @@ class RewardFunction: """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] "attribute reward_components keeps track of 
reward components and the weights assigned to each." + self.current_reward: float def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None: """Add a reward component to the reward function. @@ -249,7 +250,7 @@ """ self.reward_components.append((component, weight)) - def calculate(self, state: Dict) -> float: + def update(self, state: Dict) -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. @@ -260,7 +261,8 @@ comp = comp_and_weight[0] weight = comp_and_weight[1] total += weight * comp.calculate(state=state) - return total + self.current_reward = total + return self.current_reward @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py new file mode 100644 index 00000000..e0b849c9 --- /dev/null +++ b/src/primaite/game/io.py @@ -0,0 +1,63 @@ +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from primaite import PRIMAITE_PATHS +from primaite.simulator import SIM_OUTPUT + + +class SessionIOSettings(BaseModel): + """Schema for session IO settings.""" + + model_config = ConfigDict(extra="forbid") + + save_final_model: bool = True + """Whether to save the final model right at the end of training.""" + save_checkpoints: bool = False + """Whether to save a checkpoint model every `checkpoint_interval` episodes.""" + checkpoint_interval: int = 10 + """How often to save a checkpoint model (if save_checkpoints is True).""" + save_logs: bool = True + """Whether to save logs.""" + save_transactions: bool = True + """Whether to save transactions. If true, the session path will have a transactions folder.""" + save_tensorboard_logs: bool = False + """Whether to save tensorboard logs.
If true, the session path will have a tensorboard_logs folder.""" + + +class SessionIO: + """ + Class for managing session IO. + + Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + """ + + def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: + self.settings: SessionIOSettings = settings + self.session_path: Path = self.generate_session_path() + + # set global SIM_OUTPUT path + SIM_OUTPUT.path = self.session_path / "simulation_output" + + # WARNING TODO: be careful not to re-initialise SessionIO because it will create a new path each time; + # a refactor is possibly needed + + def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: + """Create a folder for the session and return the path to it.""" + if timestamp is None: + timestamp = datetime.now() + date_str = timestamp.strftime("%Y-%m-%d") + time_str = timestamp.strftime("%H-%M-%S") + session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str + session_path.mkdir(exist_ok=True, parents=True) + return session_path + + def generate_model_save_path(self, agent_name: str) -> Path: + """Return the path where the final model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_final" + + def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: + """Return the path where the checkpoint model will be saved (including the .pt extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py new file mode 100644 index 00000000..29196112 --- /dev/null +++ b/src/primaite/game/policy/__init__.py @@ -0,0 +1,3 @@ +from primaite.game.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py new file mode 100644 index
00000000..249c3b52 --- /dev/null +++ b/src/primaite/game/policy/policy.py @@ -0,0 +1,84 @@ +"""Base class and common logic for RL policies.""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, Type, TYPE_CHECKING + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession, TrainingOptions + + +class PolicyABC(ABC): + """Base class for reinforcement learning agents.""" + + _registry: Dict[str, Type["PolicyABC"]] = {} + """ + Registry of policy types, keyed by name. + + Automatically populated when PolicyABC subclasses are defined. Used for defining from_config. + """ + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register a policy subclass. + + :param name: Identifier used by from_config to create an instance of the policy. + :type name: str + :raises ValueError: When attempting to create a policy with a duplicate name. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate policy name {identifier}") + cls._registry[identifier] = cls + return + + @abstractmethod + def __init__(self, session: "PrimaiteSession") -> None: + """ + Initialize a reinforcement learning policy. + + :param session: The session context. + :type session: PrimaiteSession + :param agents: The agents to train. 
+ :type agents: List[RLAgent] + """ + self.session: "PrimaiteSession" = session + """Reference to the session.""" + + @abstractmethod + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + pass + + @abstractmethod + def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None: + """Evaluate the agent.""" + pass + + @abstractmethod + def save(self, save_path: Path) -> None: + """Save the agent.""" + pass + + @abstractmethod + def load(self) -> None: + """Load agent from a file.""" + pass + + def close(self) -> None: + """Close the agent.""" + pass + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": + """ + Create an RL policy from a config by calling the relevant subclass's from_config method. + + Subclasses should not call super().from_config(), they should just handle creation from config. + """ + # Assume that basically the contents of training_config are passed into here. + # I should really define a config schema class using pydantic.
+ + PolicyType = cls._registry[config.rl_framework] + return PolicyType.from_config(config=config, session=session) + + # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py new file mode 100644 index 00000000..a4870054 --- /dev/null +++ b/src/primaite/game/policy/sb3.py @@ -0,0 +1,80 @@ +"""Stable baselines 3 policy.""" +from pathlib import Path +from typing import Literal, Optional, Type, TYPE_CHECKING, Union + +from stable_baselines3 import A2C, PPO +from stable_baselines3.a2c import MlpPolicy as A2C_MLP +from stable_baselines3.common.callbacks import CheckpointCallback +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.ppo import MlpPolicy as PPO_MLP + +from primaite.game.policy.policy import PolicyABC + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession, TrainingOptions + + +class SB3Policy(PolicyABC, identifier="SB3"): + """Single agent RL policy using stable baselines 3.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + """Initialize a stable baselines 3 policy.""" + super().__init__(session=session) + + self._agent_class: Type[Union[PPO, A2C]] + if algorithm == "PPO": + self._agent_class = PPO + policy = PPO_MLP + elif algorithm == "A2C": + self._agent_class = A2C + policy = A2C_MLP + else: + raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy") + self._agent = self._agent_class( + policy=policy, + env=self.session.env, + n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch + seed=seed, + ) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + if self.session.io_manager.settings.save_checkpoints: + checkpoint_callback = CheckpointCallback( + save_freq=timesteps_per_episode * 
self.session.io_manager.settings.checkpoint_interval, + save_path=self.session.io_manager.generate_model_save_path("sb3"), + name_prefix="sb3_model", + ) + else: + checkpoint_callback = None + self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback) + + def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate the agent.""" + reward_data = evaluate_policy( + self._agent, + self.session.env, + n_eval_episodes=n_episodes, + deterministic=deterministic, + return_episode_rewards=True, + ) + print(reward_data) + + def save(self, save_path: Path) -> None: + """ + Save the current policy parameters. + + Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please + refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information. + Therefore, this method is only used to save the final model. + """ + self._agent.save(save_path) + + def load(self, model_path: Path) -> None: + """Load agent from a checkpoint.""" + self._agent = self._agent_class.load(model_path, env=self.session.env) + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": + """Create an agent from config file.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index d40d0754..ad0537e8 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,18 +1,20 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" +from enum import Enum from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple -from arcd_gate.client.gate_client import ActType, GATEClient -from gymnasium import spaces +import gymnasium from gymnasium.core import ActType, 
ObsType -from gymnasium.spaces.utils import flatten, flatten_space -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, RandomAgent -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.io import SessionIO, SessionIOSettings +from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -34,109 +36,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteGATEClient(GATEClient): - """Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE.""" +class PrimaiteGymEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. - def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000): - """ - Create a new GATE client for PrimAITE. + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. + """ - :param parent_session: The parent session object. - :type parent_session: PrimaiteSession - :param service_port: The port on which the GATE service is running. 
- :type service_port: int, optional - """ - super().__init__(service_port=service_port) - self.parent_session: "PrimaiteSession" = parent_session + def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + """Initialise the environment.""" + super().__init__() + self.session: "PrimaiteSession" = session + self.agent: ProxyAgent = agents[0] - @property - def rl_framework(self) -> str: - """The reinforcement learning framework to use.""" - return self.parent_session.training_options.rl_framework + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen by the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.session.apply_agent_actions() + self.session.advance_timestep() + state = self.session.get_sim_state() + self.session.update_agents(state) - @property - def rl_algorithm(self) -> str: - """The reinforcement learning algorithm to use.""" - return self.parent_session.training_options.rl_algorithm - - @property - def seed(self) -> int | None: - """The seed to use for the environment's random number generator.""" - return self.parent_session.training_options.seed - - @property - def n_learn_episodes(self) -> int: - """The number of episodes in each learning run.""" - return self.parent_session.training_options.n_learn_episodes - - @property - def n_learn_steps(self) -> int: - """The number of steps in each learning episode.""" - return self.parent_session.training_options.n_learn_steps - - @property - def n_eval_episodes(self) -> int: - """The number of episodes in each evaluation run.""" - return self.parent_session.training_options.n_eval_episodes - - @property - def n_eval_steps(self) -> int: - """The number of steps in each evaluation episode.""" - return self.parent_session.training_options.n_eval_steps - - @property - def action_space(self) -> spaces.Space: -
"""The gym action space of the agent.""" - return self.parent_session.rl_agent.action_space.space - - @property - def observation_space(self) -> spaces.Space: - """The gymnasium observation space of the agent.""" - return flatten_space(self.parent_session.rl_agent.observation_space.space) - - def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]: - """Take a step in the environment. - - This method is called by GATE to advance the simulation by one timestep. - - :param action: The agent's action. - :type action: ActType - :return: The observation, reward, terminal flag, truncated flag, and info dictionary. - :rtype: Tuple[ObsType, float, bool, bool, Dict] - """ - self.parent_session.rl_agent.most_recent_action = action - self.parent_session.step() - state = self.parent_session.simulation.describe_state() - obs = self.parent_session.rl_agent.observation_space.observe(state) - obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - rew = self.parent_session.rl_agent.reward_function.calculate(state) - term = False - trunc = False + next_obs = self._get_obs() + reward = self.agent.reward_function.current_reward + terminated = False + truncated = self.session.calculate_truncated() info = {} - return obs, rew, term, trunc, info - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]: - """Reset the environment. + return next_obs, reward, terminated, truncated, info - This method is called when the environment is initialized and at the end of each episode. + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: + """Reset the environment.""" + self.session.reset() + state = self.session.get_sim_state() + self.session.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info - :param seed: The seed to use for the environment's random number generator. - :type seed: int, optional - :param options: Additional options for the reset. 
None are used by PrimAITE but this is included for - compatibility with GATE. - :type options: dict[str, Any], optional - :return: The initial observation and an empty info dictionary. - :rtype: Tuple[ObsType, Dict] - """ - self.parent_session.reset() - state = self.parent_session.simulation.describe_state() - obs = self.parent_session.rl_agent.observation_space.observe(state) - obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - return obs, {} + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.space - def close(self): - """Close the session, this will stop the gate client and close the simulation.""" - self.parent_session.close() + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) class PrimaiteSessionOptions(BaseModel): @@ -146,6 +101,8 @@ class PrimaiteSessionOptions(BaseModel): Currently this is used to restrict which ports and protocols exist in the world of the simulation. 
""" + model_config = ConfigDict(extra="forbid") + ports: List[str] protocols: List[str] @@ -153,48 +110,102 @@ class PrimaiteSessionOptions(BaseModel): class TrainingOptions(BaseModel): """Options for training the RL agent.""" - rl_framework: str - rl_algorithm: str - seed: Optional[int] + model_config = ConfigDict(extra="forbid") + + rl_framework: Literal["SB3", "RLLIB"] + rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int - n_learn_steps: int - n_eval_episodes: int - n_eval_steps: int + n_eval_episodes: Optional[int] = None + max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None + deterministic_eval: bool + seed: Optional[int] + n_agents: int + agent_references: List[str] + + +class SessionMode(Enum): + """Helper to keep track of the current session mode.""" + + TRAIN = "train" + EVAL = "eval" + MANUAL = "manual" class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE.""" + """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" def __init__(self): + """Initialise a PrimaiteSession object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + self.agents: List[AbstractAgent] = [] """List of agents.""" - self.rl_agent: AbstractAgent - """The agent from the list which communicates with GATE to perform reinforcement learning.""" + + self.rl_agents: List[ProxyAgent] = [] + """Subset of agent list including only the reinforcement learning agents.""" + self.step_counter: int = 0 """Current timestep within the episode.""" + self.episode_counter: int = 0 """Current episode number.""" + self.options: PrimaiteSessionOptions """Special options that apply for the entire game.""" + self.training_options: TrainingOptions """Options specific to agent training.""" + self.policy: PolicyABC + """The reinforcement learning policy.""" + self.ref_map_nodes: Dict[str, Node] = {} 
"""Mapping from unique node reference name to node object. Used when parsing config files.""" + self.ref_map_services: Dict[str, Service] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" + self.ref_map_applications: Dict[str, Application] = {} """Mapping from human-readable application reference to application object. Used for parsing config files.""" + self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" - self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self) - """Reference to a GATE Client object, which will send data to GATE service for training RL agent.""" + + self.env: PrimaiteGymEnv + """The environment that the agent can consume. Could be PrimaiteEnv.""" + + self.mode: SessionMode = SessionMode.MANUAL + """Current session mode.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" def start_session(self) -> None: - """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - self.gate_client.start() + """Commence the training session.""" + self.mode = SessionMode.TRAIN + n_learn_episodes = self.training_options.n_learn_episodes + n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + + deterministic_eval = self.training_options.deterministic_eval + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) + self.save_models() + + self.mode = SessionMode.EVAL + if n_eval_episodes > 0: + self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) + + self.mode = SessionMode.MANUAL + + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) def step(self): """ @@ -208,57 +219,76 @@ class PrimaiteSession: 4. 
Each agent chooses an action based on the observation. 5. Each agent converts the action to a request. 6. The simulation applies the requests. + + Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent + interacts with should implement a step method that calls methods used by this method. For example, if using a + single-agent gym, make sure to update the ProxyAgent's action with the action before calling + ``self.apply_agent_actions()``. """ _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") - # currently designed with assumption that all agents act once per step in order + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state + self.update_agents(sim_state) + + # Apply all actions to simulation as requests + self.apply_agent_actions() + + # Advance timestep + self.advance_timestep() + + def get_sim_state(self) -> Dict: + """Get the current state of the simulation.""" + return self.simulation.describe_state() + + def update_agents(self, state: Dict) -> None: + """Update agents' observations and rewards based on the current state.""" for agent in self.agents: - # 3. primaite session asks simulation to provide initial state - # 4. primate session gives state to all agents - # 5. primaite session asks agents to produce an action based on most recent state - _LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}") - sim_state = self.simulation.describe_state() + agent.update_observation(state) + agent.update_reward(state) - # 6. 
each agent takes most recent state and converts it to CAOS observation - agent_obs = agent.convert_state_to_obs(sim_state) + def apply_agent_actions(self) -> None: + """Apply all actions to simulation as requests.""" + for agent in self.agents: + obs = agent.observation_manager.current_observation + rew = agent.reward_function.current_reward + action_choice, options = agent.get_action(obs, rew) + request = agent.format_request(action_choice, options) + self.simulation.apply_request(request) - # 7. meanwhile each agent also takes state and calculates reward - agent_reward = agent.calculate_reward_from_state(sim_state) - - # 8. each agent takes observation and applies decision rule to observation to create CAOS - # action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action - # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of - # code should live inside of the GATE agent subclass) - # gets action in CAOS format - _LOGGER.debug("Getting agent action") - agent_action, action_options = agent.get_action(agent_obs, agent_reward) - # 9. CAOS action is converted into request (extra information might be needed to enrich - # the request, this is what the execution definition is there for) - _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements - agent_request = agent.format_request(agent_action, action_options) - - # 10. 
primaite session receives the action from the agents and asks the simulation to apply each
-        _LOGGER.debug(f"Sending request to simulation: {agent_request}")
-        self.simulation.apply_request(agent_request)
-
-        _LOGGER.debug(f"Initiating simulation step {self.step_counter}")
-        self.simulation.apply_timestep(self.step_counter)
+    def advance_timestep(self) -> None:
+        """Advance timestep."""
         self.step_counter += 1
+        _LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
+        self.simulation.apply_timestep(self.step_counter)
+
+    def calculate_truncated(self) -> bool:
+        """Calculate whether the episode is truncated."""
+        current_step = self.step_counter
+        max_steps = self.training_options.max_steps_per_episode
+        if current_step >= max_steps:
+            return True
+        return False

     def reset(self) -> None:
         """Reset the session, this will reset the simulation."""
-        return NotImplemented
+        self.episode_counter += 1
+        self.step_counter = 0
+        _LOGGER.debug(f"Resetting primaite session, episode = {self.episode_counter}")
+        self.simulation.reset_component_for_episode(self.episode_counter)

     def close(self) -> None:
-        """Close the session, this will stop the gate client and close the simulation."""
+        """Close the session, this will stop the env and close the simulation."""
         return NotImplemented

     @classmethod
-    def from_config(cls, cfg: dict) -> "PrimaiteSession":
+    def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
         """Create a PrimaiteSession object from a config dictionary.

         The config dictionary should have the following top-level keys:

-        1. training_config: options for training the RL agent. Used by GATE.
+        1. training_config: options for training the RL agent.
         2. game_config: options for the game itself. Used by PrimaiteSession.
         3. simulation: defines the network topology and the initial state of the simulation.
@@ -276,6 +306,11 @@ class PrimaiteSession: protocols=cfg["game_config"]["protocols"], ) sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) + sim = sess.simulation net = sim.network @@ -412,7 +447,7 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationSpace.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, sess) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] @@ -448,15 +483,15 @@ class PrimaiteSession: reward_function=rew_function, ) sess.agents.append(new_agent) - elif agent_type == "GATERLAgent": - new_agent = RandomAgent( + elif agent_type == "ProxyAgent": + new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=rew_function, ) sess.agents.append(new_agent) - sess.rl_agent = new_agent + sess.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -468,4 +503,12 @@ class PrimaiteSession: else: print("agent type not found") + # CREATE ENVIRONMENT + sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) + + # CREATE POLICY + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) + return sess diff --git a/src/primaite/main.py b/src/primaite/main.py index 831419d4..1699fe51 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__) def run( config_path: Optional[Union[str, Path]] = "", + agent_load_path: Optional[Union[str, Path]] = None, ) -> None: """ Run the PrimAITE Session. 
@@ -31,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteSession.from_config(cfg=cfg) + sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 8c55542f..19c86e28 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -1,14 +1,26 @@ +"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory.""" from datetime import datetime +from pathlib import Path from primaite import _PRIMAITE_ROOT -SIM_OUTPUT = None -"A path at the repo root dir to use temporarily for sim output testing while in dev." -# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path +__all__ = ["SIM_OUTPUT"] -if not SIM_OUTPUT: - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path - SIM_OUTPUT.mkdir(exist_ok=True, parents=True) + +class __SimOutput: + def __init__(self): + self._path: Path = ( + _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ) + + @property + def path(self) -> Path: + return self._path + + @path.setter + def path(self, new_path: Path) -> None: + self._path = new_path + self._path.mkdir(exist_ok=True, parents=True) + + +SIM_OUTPUT = __SimOutput() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 537cebb2..29d3a05c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -957,7 +957,7 @@ class Node(SimComponent): if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) if not 
kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index 2e5ed008..c2faeb10 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -75,7 +75,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 791e0be8..7ac6df85 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -41,7 +41,6 @@ class SysLog: JSON-like messages. """ log_path = self._get_log_path() - file_handler = logging.FileHandler(filename=log_path) file_handler.setLevel(logging.DEBUG) @@ -81,7 +80,7 @@ class SysLog: :return: Path object representing the location of the log file. 
""" - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/utils/start_gate_server.py b/src/primaite/utils/start_gate_server.py deleted file mode 100644 index d91952f2..00000000 --- a/src/primaite/utils/start_gate_server.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Utility script to start the gate server for running PrimAITE in attached mode.""" -from arcd_gate.server.gate_service import GATEService - - -def start_gate_server(): - """Start the gate server.""" - service = GATEService() - service.start() - - -if __name__ == "__main__": - start_gate_server() diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml new file mode 100644 index 00000000..80567aea --- /dev/null +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -0,0 +1,725 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + se3ed: 333 # Purposeful typo to check that error is raised with bad configuration. 
+ n_learn_steps: 2560 + n_eval_episodes: 5 + + + +game_config: + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + + agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + 
action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # FileSystem: # PrimAITE v2 stuff -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -class TempPrimaiteSession: # PrimaiteSession): +class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession 
class. Uses context manager for deletion of files upon exit. """ - # def __init__( - # self, - # training_config_path: Union[str, Path], - # lay_down_config_path: Union[str, Path], - # ): - # super().__init__(training_config_path, lay_down_config_path) - # self.setup() + @classmethod + def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession": + """Create a temporary PrimaiteSession object from a config file.""" + config_path = Path(config_path) + with open(config_path, "r") as f: + config = yaml.safe_load(f) - # @property - # def env(self) -> Primaite: - # """Direct access to the env for ease of testing.""" - # return self._agent_session._env # noqa + return super().from_config(cfg=config) - # def __enter__(self): - # return self + def __enter__(self): + return self - # def __exit__(self, type, value, tb): - # shutil.rmtree(self.session_path) - # _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") + def __exit__(self, type, value, tb): + pass -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 @pytest.fixture -def temp_primaite_session(request): - """ - Provides a temporary PrimaiteSession instance. - - It's temporary as it uses a temporary directory as the session path. - - To use this fixture you need to: - - - parametrize your test function with: - - - "temp_primaite_session" - - [[path to training config, path to lay down config]] - - Include the temp_primaite_session fixture as a param in your test - function. - - use the temp_primaite_session as a context manager assigning is the - name 'session'. - - .. 
code:: python - - from primaite.config.lay_down_config import dos_very_basic_config_path - from primaite.config.training_config import main_training_config_path - @pytest.mark.parametrize( - "temp_primaite_session", - [ - [main_training_config_path(), dos_very_basic_config_path()] - ], - indirect=True - ) - def test_primaite_session(temp_primaite_session): - with temp_primaite_session as session: - # Learning outputs are saved in session.learning_path - session.learn() - - # Evaluation outputs are saved in session.evaluation_path - session.evaluate() - - # To ensure that all files are written, you must call .close() - session.close() - - # If you need to inspect any session outputs, it must be done - # inside the context manager - - # Now that we've exited the context manager, the - # session.session_path directory and its contents are deleted - """ - training_config_path = request.param[0] - lay_down_config_path = request.param[1] - with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck: - mck.session_timestamp = datetime.now() - - return TempPrimaiteSession(training_config_path, lay_down_config_path) - - -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -@pytest.fixture -def temp_session_path() -> Path: - """ - Get a temp directory session path the test session will output to. - - :return: The session directory path. 
- """ - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - - return session_path +def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: + """Create a temporary PrimaiteSession object.""" + monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py new file mode 100644 index 00000000..b6122bad --- /dev/null +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -0,0 +1,68 @@ +import pydantic +import pytest + +from tests.conftest import TempPrimaiteSession + +CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" +TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" +EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" +MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml" + + +class TestPrimaiteSession: + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_creating_session(self, temp_primaite_session): + """Check that creating a session from config works.""" + with temp_primaite_session as session: + if not isinstance(session, TempPrimaiteSession): + raise AssertionError + + assert session is not None + assert session.simulation + assert len(session.agents) == 3 + assert len(session.rl_agents) == 1 + + assert session.policy + assert session.env + + assert session.simulation.network + assert len(session.simulation.network.nodes) == 10 + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_start_session(self, 
temp_primaite_session): + """Make sure you can go all the way through the session without errors.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + + session_path = session.io_manager.session_path + assert session_path.exists() + print(list(session_path.glob("*"))) + checkpoint_dir = session_path / "checkpoints" / "sb3_final" + assert checkpoint_dir.exists() + checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip" + checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip" + checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip" + assert checkpoint_1.exists() + assert checkpoint_2.exists() + assert not checkpoint_3.exists() + + @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) + def test_training_only_session(self, temp_primaite_session): + """Check that you can run a training-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? 
+ + @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) + def test_eval_only_session(self, temp_primaite_session): + """Check that you can load a model and run an eval-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was loaded and that the eval-only session ran + + def test_error_thrown_on_bad_configuration(self): + with pytest.raises(pydantic.ValidationError): + session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH) diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index 16c4a274..06fe5893 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -9,7 +9,7 @@ from primaite import getLogger _LOGGER = getLogger(__name__) -def get_temp_session_path(session_timestamp: datetime) -> Path: +def temp_user_sessions_path() -> Path: """ Get a temp directory session path the test session will output to.