From 21c06dbea1d3eb60ed90f7a1228973711b15fdf7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 13 Nov 2023 16:04:25 +0000 Subject: [PATCH 01/19] Remove GATE-related code. --- .azure/azure-build-deploy-docs-pipeline.yml | 11 -- .azure/azure-ci-build-pipeline.yaml | 11 -- README.md | 3 - src/primaite/game/agent/GATE_agents.py | 31 ------ src/primaite/game/agent/interface.py | 8 +- src/primaite/game/session.py | 115 +------------------- src/primaite/utils/start_gate_server.py | 12 -- 7 files changed, 3 insertions(+), 188 deletions(-) delete mode 100644 src/primaite/game/agent/GATE_agents.py delete mode 100644 src/primaite/utils/start_gate_server.py diff --git a/.azure/azure-build-deploy-docs-pipeline.yml b/.azure/azure-build-deploy-docs-pipeline.yml index f60840a7..d9926ba7 100644 --- a/.azure/azure-build-deploy-docs-pipeline.yml +++ b/.azure/azure-build-deploy-docs-pipeline.yml @@ -29,17 +29,6 @@ jobs: pip install -e .[dev] displayName: 'Install PrimAITE for docs autosummary' - - script: | - GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl) - python -m pip install $GATE_WHEEL[dev] - displayName: 'Install GATE' - condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - - - script: | - forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]" - displayName: 'Install GATE' - condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | primaite setup displayName: 'Perform PrimAITE Setup' diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index efeba284..9070270a 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -81,17 +81,6 @@ stages: displayName: 'Install PrimAITE' condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | - GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl) - python -m pip install $GATE_WHEEL[dev] - displayName: 'Install GATE' - condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - - - 
script: | - forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]" - displayName: 'Install GATE' - condition: eq( variables['Agent.OS'], 'Windows_NT' ) - - script: | primaite setup displayName: 'Perform PrimAITE Setup' diff --git a/README.md b/README.md index 9ec8164e..7fc41681 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,6 @@ python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory .\.venv\Scripts\activate pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl -pip install GATE/arcd_gate-0.1.0-py3-none-any.whl primaite setup ``` @@ -75,7 +74,6 @@ cd ~/primaite python3 -m venv .venv source .venv/bin/activate pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl -pip install arcd_gate-0.1.0-py3-none-any.whl primaite setup ``` @@ -120,7 +118,6 @@ source venv/bin/activate ```bash python3 -m pip install -e .[dev] -pip install arcd_gate-0.1.0-py3-none-any.whl ``` #### 6. Perform the PrimAITE setup: diff --git a/src/primaite/game/agent/GATE_agents.py b/src/primaite/game/agent/GATE_agents.py deleted file mode 100644 index e50d7831..00000000 --- a/src/primaite/game/agent/GATE_agents.py +++ /dev/null @@ -1,31 +0,0 @@ -# flake8: noqa -from typing import Dict, Optional, Tuple - -from gymnasium.core import ActType, ObsType - -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractGATEAgent, ObsType -from primaite.game.agent.observations import ObservationSpace -from primaite.game.agent.rewards import RewardFunction - - -class GATERLAgent(AbstractGATEAgent): - ... 
- # The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents, - # because when we are supporting MARL, the actions form multiple agents will have to be batched - - # For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply - # with the actions for those agents. - - def __init__( - self, - agent_name: str | None, - action_space: ActionManager | None, - observation_space: ObservationSpace | None, - reward_function: RewardFunction | None, - ) -> None: - super().__init__(agent_name, action_space, observation_space, reward_function) - self.most_recent_action: ActType - - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: - return self.most_recent_action diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 89f27f3f..e3b98777 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -76,7 +76,7 @@ class AbstractAgent(ABC): :return: Action to be taken in the environment. :rtype: Tuple[str, Dict] """ - # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39, + # in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action return ("DO_NOTHING", {}) @@ -108,9 +108,3 @@ class RandomAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ return self.action_space.get_action(self.action_space.space.sample()) - - -class AbstractGATEAgent(AbstractAgent): - """Base class for actors controlled via external messages, such as RL policies.""" - - ... 
diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index d40d0754..459d9668 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,11 +1,7 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional -from arcd_gate.client.gate_client import ActType, GATEClient -from gymnasium import spaces -from gymnasium.core import ActType, ObsType -from gymnasium.spaces.utils import flatten, flatten_space from pydantic import BaseModel from primaite import getLogger @@ -34,111 +30,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteGATEClient(GATEClient): - """Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE.""" - - def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000): - """ - Create a new GATE client for PrimAITE. - - :param parent_session: The parent session object. - :type parent_session: PrimaiteSession - :param service_port: The port on which the GATE service is running. 
- :type service_port: int, optional - """ - super().__init__(service_port=service_port) - self.parent_session: "PrimaiteSession" = parent_session - - @property - def rl_framework(self) -> str: - """The reinforcement learning framework to use.""" - return self.parent_session.training_options.rl_framework - - @property - def rl_algorithm(self) -> str: - """The reinforcement learning algorithm to use.""" - return self.parent_session.training_options.rl_algorithm - - @property - def seed(self) -> int | None: - """The seed to use for the environment's random number generator.""" - return self.parent_session.training_options.seed - - @property - def n_learn_episodes(self) -> int: - """The number of episodes in each learning run.""" - return self.parent_session.training_options.n_learn_episodes - - @property - def n_learn_steps(self) -> int: - """The number of steps in each learning episode.""" - return self.parent_session.training_options.n_learn_steps - - @property - def n_eval_episodes(self) -> int: - """The number of episodes in each evaluation run.""" - return self.parent_session.training_options.n_eval_episodes - - @property - def n_eval_steps(self) -> int: - """The number of steps in each evaluation episode.""" - return self.parent_session.training_options.n_eval_steps - - @property - def action_space(self) -> spaces.Space: - """The gym action space of the agent.""" - return self.parent_session.rl_agent.action_space.space - - @property - def observation_space(self) -> spaces.Space: - """The gymnasium observation space of the agent.""" - return flatten_space(self.parent_session.rl_agent.observation_space.space) - - def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]: - """Take a step in the environment. - - This method is called by GATE to advance the simulation by one timestep. - - :param action: The agent's action. - :type action: ActType - :return: The observation, reward, terminal flag, truncated flag, and info dictionary. 
- :rtype: Tuple[ObsType, float, bool, bool, Dict] - """ - self.parent_session.rl_agent.most_recent_action = action - self.parent_session.step() - state = self.parent_session.simulation.describe_state() - obs = self.parent_session.rl_agent.observation_space.observe(state) - obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - rew = self.parent_session.rl_agent.reward_function.calculate(state) - term = False - trunc = False - info = {} - return obs, rew, term, trunc, info - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]: - """Reset the environment. - - This method is called when the environment is initialized and at the end of each episode. - - :param seed: The seed to use for the environment's random number generator. - :type seed: int, optional - :param options: Additional options for the reset. None are used by PrimAITE but this is included for - compatibility with GATE. - :type options: dict[str, Any], optional - :return: The initial observation and an empty info dictionary. - :rtype: Tuple[ObsType, Dict] - """ - self.parent_session.reset() - state = self.parent_session.simulation.describe_state() - obs = self.parent_session.rl_agent.observation_space.observe(state) - obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - return obs, {} - - def close(self): - """Close the session, this will stop the gate client and close the simulation.""" - self.parent_session.close() - - class PrimaiteSessionOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. @@ -189,12 +80,10 @@ class PrimaiteSession: """Mapping from human-readable application reference to application object. Used for parsing config files.""" self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. 
Used when parsing config files.""" - self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self) - """Reference to a GATE Client object, which will send data to GATE service for training RL agent.""" def start_session(self) -> None: """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - self.gate_client.start() + raise NotImplementedError def step(self): """ diff --git a/src/primaite/utils/start_gate_server.py b/src/primaite/utils/start_gate_server.py deleted file mode 100644 index d91952f2..00000000 --- a/src/primaite/utils/start_gate_server.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Utility script to start the gate server for running PrimAITE in attached mode.""" -from arcd_gate.server.gate_service import GATEService - - -def start_gate_server(): - """Start the gate server.""" - service = GATEService() - service.start() - - -if __name__ == "__main__": - start_gate_server() From 707f2b59af1f1039957952ff2fbccfe78b74edaa Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 13 Nov 2023 16:08:39 +0000 Subject: [PATCH 02/19] Add SB3 RL agent --- src/primaite/game/policy/__init__.py | 0 src/primaite/game/policy/policy.py | 58 ++++++++++++++++++ src/primaite/game/policy/sb3.py | 89 ++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+) create mode 100644 src/primaite/game/policy/__init__.py create mode 100644 src/primaite/game/policy/policy.py create mode 100644 src/primaite/game/policy/sb3.py diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py new file mode 100644 index 00000000..8d5a9a08 --- /dev/null +++ b/src/primaite/game/policy/policy.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractclassmethod, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession + + +class 
PolicyABC(ABC): + """Base class for reinforcement learning agents.""" + + @abstractmethod + def __init__(self, session: "PrimaiteSession") -> None: + """Initialize a reinforcement learning agent.""" + self.session: "PrimaiteSession" = session + pass + + @abstractmethod + def learn( + self, + ) -> None: + """Train the agent.""" + pass + + @abstractmethod + def eval( + self, + ) -> None: + """Evaluate the agent.""" + pass + + @abstractmethod + def save( + self, + ) -> None: + """Save the agent.""" + pass + + @abstractmethod + def load( + self, + ) -> None: + """Load agent from a file.""" + pass + + def close( + self, + ) -> None: + """Close the agent.""" + pass + + @abstractclassmethod + def from_config( + cls, + ) -> "PolicyABC": + """Create an agent from a config file.""" + pass + + # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py new file mode 100644 index 00000000..9c6b49ae --- /dev/null +++ b/src/primaite/game/policy/sb3.py @@ -0,0 +1,89 @@ +from typing import Literal, TYPE_CHECKING, Union + +from stable_baselines3 import A2C, PPO +from stable_baselines3.a2c import MlpPolicy as A2C_MLP +from stable_baselines3.ppo import MlpPolicy as PPO_MLP + +from primaite.game.policy.policy import PolicyABC + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession + + +class SB3Policy(PolicyABC): + """Single agent RL policy using stable baselines 3.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"]): + """Initialize a stable baselines 3 policy.""" + super().__init__(session=session) + + self._agent_class: type[Union[PPO, A2C]] + if algorithm == "PPO": + self._agent_class = PPO + policy = PPO_MLP + elif algorithm == "A2C": + self._agent_class = A2C + policy = A2C_MLP + else: + raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy") + self._agent = 
self._agent_class( + policy=policy, + env=self.session.env, + n_steps=..., + seed=..., + ) # TODO: populate values once I figure out how to get them from the config / session + + def learn( + self, + ) -> None: + """Train the agent.""" + time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session + episodes = 10 # TODO: populate values once I figure out how to get them from the config / session + for i in range(episodes): + self._agent.learn(total_timesteps=time_steps) + self._save_checkpoint() + pass + + def eval( + self, + ) -> None: + """Evaluate the agent.""" + time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session + num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session + deterministic = True # TODO: populate values once I figure out how to get them from the config / session + + for episode in range(num_episodes): + obs = self.session.env.reset() + for step in range(time_steps): + action, _states = self._agent.predict(obs, deterministic=deterministic) + obs, rewards, truncated, terminated, info = self.session.env.step(action) + + def save( + self, + ) -> None: + """Save the agent.""" + savepath = ( + "temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session + ) + self._agent.save(savepath) + pass + + def load( + self, + ) -> None: + """Load agent from a checkpoint.""" + self._agent_class.load("temp/path/to/save.pth", env=self.session.env) + pass + + def close( + self, + ) -> None: + """Close the agent.""" + pass + + @classmethod + def from_config( + self, + ) -> "SB3Policy": + """Create an agent from config file.""" + pass From 08e88e52b0bb5961a67bfc1d11f9930681e511a6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 13 Nov 2023 16:35:35 +0000 Subject: [PATCH 03/19] Begin implementing training loop in session --- src/primaite/game/policy/sb3.py | 1 + 
src/primaite/game/session.py | 27 ++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 9c6b49ae..151e860d 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -52,6 +52,7 @@ class SB3Policy(PolicyABC): num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session deterministic = True # TODO: populate values once I figure out how to get them from the config / session + # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB for episode in range(num_episodes): obs = self.session.env.reset() for step in range(time_steps): diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 459d9668..a088d05e 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -9,6 +9,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, RandomAgent from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction +from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -59,31 +60,51 @@ class PrimaiteSession: def __init__(self): self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + self.agents: List[AbstractAgent] = [] """List of agents.""" - self.rl_agent: AbstractAgent - """The agent from the list which communicates with GATE to perform reinforcement learning.""" + + # self.rl_agent: AbstractAgent + # """The agent from the list which communicates with GATE to perform reinforcement learning.""" + self.step_counter: int = 0 """Current timestep within the 
episode.""" + self.episode_counter: int = 0 """Current episode number.""" + self.options: PrimaiteSessionOptions """Special options that apply for the entire game.""" + self.training_options: TrainingOptions """Options specific to agent training.""" + self.policy: PolicyABC + """The reinforcement learning policy.""" + self.ref_map_nodes: Dict[str, Node] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" + self.ref_map_services: Dict[str, Service] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" + self.ref_map_applications: Dict[str, Application] = {} """Mapping from human-readable application reference to application object. Used for parsing config files.""" + self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" def start_session(self) -> None: """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - raise NotImplementedError + # n_learn_steps = self.training_options.n_learn_steps + n_learn_episodes = self.training_options.n_learn_episodes + # n_eval_steps = self.training_options.n_eval_steps + n_eval_episodes = self.training_options.n_eval_episodes + if n_learn_episodes > 0: + self.policy.learn() + + if n_eval_episodes > 0: + self.policy.eval() def step(self): """ From 1cb54da2dd91e2067a8214eaa7290ad3819667e1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 13 Nov 2023 17:12:50 +0000 Subject: [PATCH 04/19] Remove more GATE stuff --- src/primaite/cli.py | 11 ------ .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/policy/policy.py | 8 +--- src/primaite/game/policy/sb3.py | 39 ++++++------------- src/primaite/game/session.py | 9 +++-- 5 files changed, 19 insertions(+), 50 deletions(-) diff --git a/src/primaite/cli.py b/src/primaite/cli.py index a5b3be46..0f17525e 100644 --- a/src/primaite/cli.py +++ 
b/src/primaite/cli.py @@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ - from arcd_gate.cli import setup as gate_setup - from primaite import getLogger from primaite.setup import reset_demo_notebooks, reset_example_configs @@ -115,9 +113,6 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Rebuilding the example notebooks...") reset_example_configs.run(overwrite_existing=True) - _LOGGER.info("Setting up ARCD GATE...") - gate_setup() - _LOGGER.info("PrimAITE setup complete!") @@ -131,14 +126,8 @@ def session( :param config: The path to the config file. Optional, if None, the example config will be used. :type config: Optional[str] """ - from threading import Thread - from primaite.config.load import example_config_path from primaite.main import run - from primaite.utils.start_gate_server import start_gate_server - - server_thread = Thread(target=start_gate_server) - server_thread.start() if not config: config = example_config_path() diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index ee42cf4f..676028bb 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -108,7 +108,7 @@ game_config: - ref: defender team: BLUE - type: GATERLAgent + type: idk??? 
observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 8d5a9a08..404d6f31 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -15,16 +15,12 @@ class PolicyABC(ABC): pass @abstractmethod - def learn( - self, - ) -> None: + def learn(self, n_episodes: int, n_time_steps: int) -> None: """Train the agent.""" pass @abstractmethod - def eval( - self, - ) -> None: + def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: """Evaluate the agent.""" pass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 151e860d..2d9da1db 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -33,35 +33,24 @@ class SB3Policy(PolicyABC): seed=..., ) # TODO: populate values once I figure out how to get them from the config / session - def learn( - self, - ) -> None: + def learn(self, n_episodes: int, n_time_steps: int) -> None: """Train the agent.""" - time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session - episodes = 10 # TODO: populate values once I figure out how to get them from the config / session - for i in range(episodes): - self._agent.learn(total_timesteps=time_steps) + # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB + for i in range(n_episodes): + self._agent.learn(total_timesteps=n_time_steps) self._save_checkpoint() pass - def eval( - self, - ) -> None: + def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: """Evaluate the agent.""" - time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session - num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session - deterministic = True # TODO: populate values once I figure out how to get them from the config / session - # TODO: 
consider moving this loop to the session, only if this makes sense for RAY RLLIB - for episode in range(num_episodes): + for episode in range(n_episodes): obs = self.session.env.reset() - for step in range(time_steps): + for step in range(n_time_steps): action, _states = self._agent.predict(obs, deterministic=deterministic) obs, rewards, truncated, terminated, info = self.session.env.step(action) - def save( - self, - ) -> None: + def save(self) -> None: """Save the agent.""" savepath = ( "temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session @@ -69,22 +58,16 @@ class SB3Policy(PolicyABC): self._agent.save(savepath) pass - def load( - self, - ) -> None: + def load(self) -> None: """Load agent from a checkpoint.""" self._agent_class.load("temp/path/to/save.pth", env=self.session.env) pass - def close( - self, - ) -> None: + def close(self) -> None: """Close the agent.""" pass @classmethod - def from_config( - self, - ) -> "SB3Policy": + def from_config(self) -> "SB3Policy": """Create an agent from config file.""" pass diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a088d05e..9d241932 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -96,15 +96,16 @@ class PrimaiteSession: def start_session(self) -> None: """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - # n_learn_steps = self.training_options.n_learn_steps + n_learn_steps = self.training_options.n_learn_steps n_learn_episodes = self.training_options.n_learn_episodes - # n_eval_steps = self.training_options.n_eval_steps + n_eval_steps = self.training_options.n_eval_steps n_eval_episodes = self.training_options.n_eval_episodes + deterministic_eval = True # TODO: get this value from config if n_learn_episodes > 0: - self.policy.learn() + self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps) if n_eval_episodes > 0: - 
self.policy.eval() + self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval) def step(self): """ From e6ead6e53248695f0699eb63a83dcada0d154b67 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 14 Nov 2023 15:10:07 +0000 Subject: [PATCH 05/19] Update agent interface to work better with envs --- .../config/_package_data/example_config.yaml | 8 +- src/primaite/game/agent/interface.py | 70 ++++++-- src/primaite/game/agent/observations.py | 11 +- src/primaite/game/agent/rewards.py | 6 +- src/primaite/game/policy/policy.py | 69 +++++--- src/primaite/game/policy/sb3.py | 15 +- src/primaite/game/session.py | 158 +++++++++++++----- 7 files changed, 240 insertions(+), 97 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 676028bb..0c39333c 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -4,8 +4,12 @@ training_config: seed: 333 n_learn_episodes: 20 n_learn_steps: 128 - n_eval_episodes: 20 + n_eval_episodes: 5 n_eval_steps: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender game_config: @@ -108,7 +112,7 @@ game_config: - ref: defender team: BLUE - type: idk??? 
+ type: RLAgent observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index e3b98777..75d209ce 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,15 +1,13 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Dict, List, Optional, Tuple -import numpy as np +from gymnasium.core import ActType, ObsType from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -ObsType: TypeAlias = Union[Dict, np.ndarray] - class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -18,7 +16,7 @@ class AbstractAgent(ABC): self, agent_name: Optional[str], action_space: Optional[ActionManager], - observation_space: Optional[ObservationSpace], + observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], ) -> None: """ @@ -34,24 +32,24 @@ class AbstractAgent(ABC): :type reward_function: Optional[RewardFunction] """ self.agent_name: str = agent_name or "unnamed_agent" - self.action_space: Optional[ActionManager] = action_space - self.observation_space: Optional[ObservationSpace] = observation_space + self.action_manager: Optional[ActionManager] = action_space + self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info # by for example specifying target ip addresses, or converting a node ID into a uuid self.execution_definition = None - def convert_state_to_obs(self, state: Dict) -> ObsType: + def update_observation(self, state: Dict) -> ObsType: """ Convert 
a state from the simulator into an observation for the agent using the observation space. state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ - return self.observation_space.observe(state) + return self.observation_manager.update(state) - def calculate_reward_from_state(self, state: Dict) -> float: + def update_reward(self, state: Dict) -> float: """ Use the reward function to calculate a reward from the state. @@ -60,10 +58,10 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.calculate(state) + return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -84,7 +82,7 @@ class AbstractAgent(ABC): # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - request = self.action_space.form_request(action_identifier=action, action_options=options) + request = self.action_manager.form_request(action_identifier=action, action_options=options) return request @@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """Randomly sample an action from the action space. 
:param obs: _description_ @@ -107,4 +105,44 @@ class RandomAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - return self.action_space.get_action(self.action_space.space.sample()) + return self.action_manager.get_action(self.action_manager.space.sample()) + + +class ProxyAgent(AbstractAgent): + """Agent that sends observations to an RL model and receives actions from that model.""" + + def __init__( + self, + agent_name: Optional[str], + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + ) -> None: + super().__init__( + agent_name=agent_name, + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + self.most_recent_action: ActType + + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + """ + Return the agent's most recent action, formatted in CAOS format. + + :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. + :type obs: ObsType + :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. + :type reward: float, optional + :return: Action to be taken in CAOS format. + :rtype: Tuple[str, Dict] + """ + return self.action_manager.get_action(self.most_recent_action) + + def store_action(self, action: ActType): + """ + Store the most recent action taken by the agent. + + The environment is responsible for calling this method when it receives an action from the agent policy. 
+ """ + self.most_recent_action = action diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a3bafeea..a74771c0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces +from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation): pass -class ObservationSpace: +class ObservationManager: """ Manage the observations of an Agent. @@ -947,15 +948,17 @@ class ObservationSpace: :type observation: AbstractObservation """ self.obs: AbstractObservation = observation + self.current_observation: ObsType - def observe(self, state: Dict) -> Dict: + def update(self, state: Dict) -> Dict: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary :type state: Dict """ - return self.obs.observe(state) + self.current_observation = self.obs.observe(state) + return self.current_observation @property def space(self) -> None: @@ -963,7 +966,7 @@ class ObservationSpace: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace": + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. 
diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 6c408ff9..49d56e67 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -238,6 +238,7 @@ class RewardFunction: """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] "attribute reward_components keeps track of reward components and the weights assigned to each." + self.current_reward: float def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None: """Add a reward component to the reward function. @@ -249,7 +250,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def calculate(self, state: Dict) -> float: + def update(self, state: Dict) -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. @@ -260,7 +261,8 @@ class RewardFunction: comp = comp_and_weight[0] weight = comp_and_weight[1] total += weight * comp.calculate(state=state) - return total + self.current_reward = total + return self.current_reward @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 404d6f31..5669a4ff 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -1,18 +1,47 @@ -from abc import ABC, abstractclassmethod, abstractmethod -from typing import TYPE_CHECKING +"""Base class and common logic for RL policies.""" +from abc import ABC, abstractmethod +from typing import Any, Dict, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.session import PrimaiteSession, TrainingOptions class PolicyABC(ABC): """Base class for reinforcement learning agents.""" + _registry: Dict[str, type["PolicyABC"]] = {} + """ + Registry of policy types, keyed by name. 
+ + Automatically populated when PolicyABC subclasses are defined. Used for defining from_config. + """ + + def __init_subclass__(cls, name: str, **kwargs: Any) -> None: + """ + Register a policy subclass. + + :param name: Identifier used by from_config to create an instance of the policy. + :type name: str + :raises ValueError: When attempting to create a policy with a duplicate name. + """ + super().__init_subclass__(**kwargs) + if name in cls._registry: + raise ValueError(f"Duplicate policy name {name}") + cls._registry[name] = cls + return + @abstractmethod def __init__(self, session: "PrimaiteSession") -> None: - """Initialize a reinforcement learning agent.""" + """ + Initialize a reinforcement learning policy. + + :param session: The session context. + :type session: PrimaiteSession + :param agents: The agents to train. + :type agents: List[RLAgent] + """ self.session: "PrimaiteSession" = session - pass + """Reference to the session.""" @abstractmethod def learn(self, n_episodes: int, n_time_steps: int) -> None: @@ -25,30 +54,30 @@ pass @abstractmethod - def save( - self, - ) -> None: + def save(self) -> None: """Save the agent.""" pass @abstractmethod - def load( - self, - ) -> None: + def load(self) -> None: """Load agent from a file.""" pass - def close( - self, - ) -> None: + def close(self) -> None: """Close the agent.""" pass - @abstractclassmethod - def from_config( - cls, - ) -> "PolicyABC": - """Create an agent from a config file.""" - pass + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": + """ + Create an RL policy from a config by calling the relevant subclass's from_config method. + + Subclasses should not call super().from_config(), they should just handle creation from config. + """ + # Assume that basically the contents of training_config are passed into here. + # I should really define a config schema class using pydantic. 
+ + PolicyType = cls._registry[config.rl_framework] + return PolicyType.from_config() # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 2d9da1db..73df1b98 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,4 +1,5 @@ -from typing import Literal, TYPE_CHECKING, Union +"""Stable baselines 3 policy.""" +from typing import Literal, Optional, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP @@ -7,13 +8,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.session import PrimaiteSession, TrainingOptions class SB3Policy(PolicyABC): """Single agent RL policy using stable baselines 3.""" - def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"]): + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) @@ -29,8 +30,8 @@ class SB3Policy(PolicyABC): self._agent = self._agent_class( policy=policy, env=self.session.env, - n_steps=..., - seed=..., + n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch + seed=seed, ) # TODO: populate values once I figure out how to get them from the config / session def learn(self, n_episodes: int, n_time_steps: int) -> None: @@ -68,6 +69,6 @@ class SB3Policy(PolicyABC): pass @classmethod - def from_config(self) -> "SB3Policy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": """Create an agent from config file.""" - pass + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git 
a/src/primaite/game/session.py b/src/primaite/game/session.py index 9d241932..5556dd87 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,13 +1,15 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple +import gymnasium +from gymnasium.core import ActType, ObsType from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, RandomAgent -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -31,6 +33,58 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) +class PrimaiteEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. + + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. 
+ """ + + def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + """Initialise the environment.""" + super().__init__() + self.session: "PrimaiteSession" = session + self.agent: ProxyAgent = agents[0] + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen my the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.session.apply_agent_actions() + self.session.advance_timestep() + state = self.session.get_sim_state() + self.session.update_agents(state) + + next_obs = self.agent.observation_manager.current_observation + reward = self.agent.reward_function.current_reward + terminated = False + truncated = ... + info = {} + + return next_obs, reward, terminated, truncated, info + + def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]: + """Reset the environment.""" + self.session.reset() + state = self.session.get_sim_state() + self.session.update_agents(state) + next_obs = self.agent.observation_manager.current_observation + info = {} + return next_obs, info + + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.action_space + + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return self.agent.observation_manager.observation_space + + class PrimaiteSessionOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. 
@@ -45,28 +99,29 @@ class PrimaiteSessionOptions(BaseModel): class TrainingOptions(BaseModel): """Options for training the RL agent.""" - rl_framework: str - rl_algorithm: str + rl_framework: Literal["SB3", "RLLIB"] + rl_algorithm: Literal["PPO", "A2C"] seed: Optional[int] n_learn_episodes: int n_learn_steps: int - n_eval_episodes: int - n_eval_steps: int + n_eval_episodes: int = 0 + n_eval_steps: Optional[int] = None + deterministic_eval: bool + n_agents: int + agent_references: List[str] class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE.""" + """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" def __init__(self): + """Initialise a PrimaiteSession object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" self.agents: List[AbstractAgent] = [] """List of agents.""" - # self.rl_agent: AbstractAgent - # """The agent from the list which communicates with GATE to perform reinforcement learning.""" - self.step_counter: int = 0 """Current timestep within the episode.""" @@ -94,8 +149,10 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" + # self.env: + def start_session(self) -> None: - """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" + """Commence the training session.""" n_learn_steps = self.training_options.n_learn_steps n_learn_episodes = self.training_options.n_learn_episodes n_eval_steps = self.training_options.n_eval_steps @@ -119,40 +176,47 @@ class PrimaiteSession: 4. Each agent chooses an action based on the observation. 5. Each agent converts the action to a request. 6. The simulation applies the requests. + + Warning: This method should only be used with scripted agents. 
For RL agents, the environment that the agent + interacts with should implement a step method that calls methods used by this method. For example, if using a + single-agent gym, make sure to update the ProxyAgent's action with the action before calling + ``self.apply_agent_actions()``. """ _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") - # currently designed with assumption that all agents act once per step in order + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state + self.update_agents(sim_state) + + # Apply all actions to simulation as requests + self.apply_agent_actions() + + # Advance timestep + self.advance_timestep() + + def get_sim_state(self) -> Dict: + """Get the current state of the simulation.""" + return self.simulation.describe_state() + + def update_agents(self, state: Dict) -> None: + """Update agents' observations and rewards based on the current state.""" for agent in self.agents: - # 3. primaite session asks simulation to provide initial state - # 4. primate session gives state to all agents - # 5. primaite session asks agents to produce an action based on most recent state - _LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}") - sim_state = self.simulation.describe_state() + agent.update_observation(state) + agent.update_reward(state) - # 6. each agent takes most recent state and converts it to CAOS observation - agent_obs = agent.convert_state_to_obs(sim_state) + def apply_agent_actions(self) -> None: + """Apply all actions to simulation as requests.""" + for agent in self.agents: + obs = agent.observation_manager.current_observation + rew = agent.reward_function.current_reward + action_choice, options = agent.get_action(obs, rew) + request = agent.format_request(action_choice, options) + self.simulation.apply_request(request) - # 7. 
meanwhile each agent also takes state and calculates reward - agent_reward = agent.calculate_reward_from_state(sim_state) - - # 8. each agent takes observation and applies decision rule to observation to create CAOS - # action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action - # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of - # code should live inside of the GATE agent subclass) - # gets action in CAOS format - _LOGGER.debug("Getting agent action") - agent_action, action_options = agent.get_action(agent_obs, agent_reward) - # 9. CAOS action is converted into request (extra information might be needed to enrich - # the request, this is what the execution definition is there for) - _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements - agent_request = agent.format_request(agent_action, action_options) - - # 10. primaite session receives the action from the agents and asks the simulation to apply each - _LOGGER.debug(f"Sending request to simulation: {agent_request}") - self.simulation.apply_request(agent_request) - - _LOGGER.debug(f"Initiating simulation step {self.step_counter}") + def advance_timestep(self) -> None: + """Advance timestep.""" self.simulation.apply_timestep(self.step_counter) self.step_counter += 1 @@ -161,7 +225,7 @@ class PrimaiteSession: return NotImplemented def close(self) -> None: - """Close the session, this will stop the gate client and close the simulation.""" + """Close the session, this will stop the env and close the simulation.""" return NotImplemented @classmethod @@ -169,7 +233,7 @@ class PrimaiteSession: """Create a PrimaiteSession object from a config dictionary. The config dictionary should have the following top-level keys: - 1. training_config: options for training the RL agent. Used by GATE. + 1. training_config: options for training the RL agent. 2. game_config: options for the game itself. Used by PrimaiteSession. 3. 
simulation: defines the network topology and the initial state of the simulation. @@ -323,7 +387,7 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationSpace.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, sess) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] @@ -359,15 +423,14 @@ class PrimaiteSession: reward_function=rew_function, ) sess.agents.append(new_agent) - elif agent_type == "GATERLAgent": - new_agent = RandomAgent( + elif agent_type == "RLAgent": + new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=rew_function, ) sess.agents.append(new_agent) - sess.rl_agent = new_agent elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -379,4 +442,7 @@ class PrimaiteSession: else: print("agent type not found") + # CREATE POLICY + sess.policy = PolicyABC.from_config(sess.training_options) + return sess From c8f2f193bd609a665f847f926beeebca55bf10b2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 15 Nov 2023 12:52:18 +0000 Subject: [PATCH 06/19] Implement agent training with sb3 --- .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/policy/__init__.py | 3 ++ src/primaite/game/policy/policy.py | 10 +++--- src/primaite/game/policy/sb3.py | 9 +++-- src/primaite/game/session.py | 34 +++++++++++++------ 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 0c39333c..17e5f5a5 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -112,7 +112,7 @@ game_config: - ref: defender team: BLUE - type: RLAgent + type: ProxyAgent observation_space: type: UC2BlueObservation diff --git 
a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py index e69de29b..29196112 100644 --- a/src/primaite/game/policy/__init__.py +++ b/src/primaite/game/policy/__init__.py @@ -0,0 +1,3 @@ +from primaite.game.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 5669a4ff..4c8dc447 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -16,7 +16,7 @@ class PolicyABC(ABC): Automatically populated when PolicyABC subclasses are defined. Used for defining from_config. """ - def __init_subclass__(cls, name: str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: """ Register a policy subclass. @@ -25,9 +25,9 @@ class PolicyABC(ABC): :raises ValueError: When attempting to create a policy with a duplicate name. """ super().__init_subclass__(**kwargs) - if name in cls._registry: - raise ValueError(f"Duplicate policy name {name}") - cls._registry[name] = cls + if identifier in cls._registry: + raise ValueError(f"Duplicate policy name {identifier}") + cls._registry[identifier] = cls return @abstractmethod @@ -78,6 +78,6 @@ class PolicyABC(ABC): # I should really define a config schema class using pydantic. 
PolicyType = cls._registry[config.rl_framework] - return PolicyType.from_config() + return PolicyType.from_config(config=config, session=session) # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 73df1b98..391b3115 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,6 +1,7 @@ """Stable baselines 3 policy.""" from typing import Literal, Optional, TYPE_CHECKING, Union +import numpy as np from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP from stable_baselines3.ppo import MlpPolicy as PPO_MLP @@ -11,7 +12,7 @@ if TYPE_CHECKING: from primaite.game.session import PrimaiteSession, TrainingOptions -class SB3Policy(PolicyABC): +class SB3Policy(PolicyABC, identifier="SB3"): """Single agent RL policy using stable baselines 3.""" def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): @@ -39,16 +40,18 @@ class SB3Policy(PolicyABC): # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB for i in range(n_episodes): self._agent.learn(total_timesteps=n_time_steps) - self._save_checkpoint() + # self._save_checkpoint() pass def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: """Evaluate the agent.""" # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB for episode in range(n_episodes): - obs = self.session.env.reset() + obs, info = self.session.env.reset() for step in range(n_time_steps): action, _states = self._agent.predict(obs, deterministic=deterministic) + if isinstance(action, np.ndarray): + action = np.int64(action) obs, rewards, truncated, terminated, info = self.session.env.step(action) def save(self) -> None: diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 
5556dd87..8017d0d4 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -33,7 +33,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteEnv(gymnasium.Env): +class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. @@ -57,10 +57,10 @@ class PrimaiteEnv(gymnasium.Env): state = self.session.get_sim_state() self.session.update_agents(state) - next_obs = self.agent.observation_manager.current_observation + next_obs = self._get_obs() reward = self.agent.reward_function.current_reward terminated = False - truncated = ... + truncated = False info = {} return next_obs, reward, terminated, truncated, info @@ -70,19 +70,25 @@ class PrimaiteEnv(gymnasium.Env): self.session.reset() state = self.session.get_sim_state() self.session.update_agents(state) - next_obs = self.agent.observation_manager.current_observation + next_obs = self._get_obs() info = {} return next_obs, info @property def action_space(self) -> gymnasium.Space: """Return the action space of the environment.""" - return self.agent.action_manager.action_space + return self.agent.action_manager.space @property def observation_space(self) -> gymnasium.Space: """Return the observation space of the environment.""" - return self.agent.observation_manager.observation_space + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) class PrimaiteSessionOptions(BaseModel): @@ -122,6 +128,9 @@ class PrimaiteSession: self.agents: List[AbstractAgent] = [] """List of agents.""" + self.rl_agents: List[ProxyAgent] = [] + """Subset of agent list including only the reinforcement learning agents.""" + self.step_counter: 
int = 0 """Current timestep within the episode.""" @@ -149,7 +158,8 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" - # self.env: + self.env: PrimaiteGymEnv + """The environment that the agent can consume. Could be PrimaiteEnv.""" def start_session(self) -> None: """Commence the training session.""" @@ -423,7 +433,7 @@ class PrimaiteSession: reward_function=rew_function, ) sess.agents.append(new_agent) - elif agent_type == "RLAgent": + elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -431,6 +441,7 @@ class PrimaiteSession: reward_function=rew_function, ) sess.agents.append(new_agent) + sess.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -442,7 +453,10 @@ class PrimaiteSession: else: print("agent type not found") - # CREATE POLICY - sess.policy = PolicyABC.from_config(sess.training_options) + # CREATE ENVIRONMENT + sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) + + # CREATE POLICY + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) return sess From 6182b53bfd6858d8c33d70ad8adfa1f8ca2dbabb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 15 Nov 2023 14:49:44 +0000 Subject: [PATCH 07/19] Fix incorrect number of steps per episode --- src/primaite/__init__.py | 1 + .../config/_package_data/example_config.yaml | 5 +- src/primaite/game/policy/sb3.py | 30 ++++----- src/primaite/game/session.py | 65 +++++++++++++++---- 4 files changed, 69 insertions(+), 32 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 30fc9ab9..789517f7 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict: "DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARN": logging.WARN, + "WARNING": 
logging.WARN, "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 17e5f5a5..dca9620f 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,10 +2,9 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_episodes: 20 - n_learn_steps: 128 + n_learn_steps: 2560 n_eval_episodes: 5 - n_eval_steps: 128 + max_steps_per_episode: 128 deterministic_eval: false n_agents: 1 agent_references: diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 391b3115..ff710944 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,9 +1,9 @@ """Stable baselines 3 policy.""" from typing import Literal, Optional, TYPE_CHECKING, Union -import numpy as np from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP +from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC @@ -33,26 +33,22 @@ class SB3Policy(PolicyABC, identifier="SB3"): env=self.session.env, n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch seed=seed, - ) # TODO: populate values once I figure out how to get them from the config / session + ) - def learn(self, n_episodes: int, n_time_steps: int) -> None: + def learn(self, n_time_steps: int) -> None: """Train the agent.""" - # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB - for i in range(n_episodes): - self._agent.learn(total_timesteps=n_time_steps) - # self._save_checkpoint() - pass + self._agent.learn(total_timesteps=n_time_steps) - def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: + def eval(self, n_episodes: int, 
deterministic: bool) -> None: """Evaluate the agent.""" - # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB - for episode in range(n_episodes): - obs, info = self.session.env.reset() - for step in range(n_time_steps): - action, _states = self._agent.predict(obs, deterministic=deterministic) - if isinstance(action, np.ndarray): - action = np.int64(action) - obs, rewards, truncated, terminated, info = self.session.env.step(action) + reward_data = evaluate_policy( + self._agent, + self.session.env, + n_eval_episodes=n_episodes, + deterministic=deterministic, + return_episode_rewards=True, + ) + print(reward_data) def save(self) -> None: """Save the agent.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 8017d0d4..e85328ef 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,7 +1,9 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" +from enum import Enum from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple +import enlighten import gymnasium from gymnasium.core import ActType, ObsType from pydantic import BaseModel @@ -30,6 +32,8 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer +progress_bar_manager = enlighten.get_manager() + _LOGGER = getLogger(__name__) @@ -60,7 +64,7 @@ class PrimaiteGymEnv(gymnasium.Env): next_obs = self._get_obs() reward = self.agent.reward_function.current_reward terminated = False - truncated = False + truncated = self.session.calculate_truncated() info = {} return next_obs, reward, terminated, truncated, info @@ -108,15 +112,22 @@ class TrainingOptions(BaseModel): rl_framework: Literal["SB3", "RLLIB"] rl_algorithm: Literal["PPO", "A2C"] seed: Optional[int] - n_learn_episodes: int 
n_learn_steps: int - n_eval_episodes: int = 0 - n_eval_steps: Optional[int] = None + n_eval_episodes: Optional[int] = None + max_steps_per_episode: int deterministic_eval: bool n_agents: int agent_references: List[str] +class SessionMode(Enum): + """Helper to keep track of the current session mode.""" + + TRAIN = "train" + EVAL = "eval" + MANUAL = "manual" + + class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" @@ -161,18 +172,31 @@ class PrimaiteSession: self.env: PrimaiteGymEnv """The environment that the agent can consume. Could be PrimaiteEnv.""" + self.training_progress_bar: Optional[enlighten.Counter] = None + """training steps counter""" + + self.eval_progress_bar: Optional[enlighten.Counter] = None + """evaluation episodes counter""" + + self.mode: SessionMode = SessionMode.MANUAL + def start_session(self) -> None: """Commence the training session.""" + self.mode = SessionMode.TRAIN + self.training_progress_bar = progress_bar_manager.counter( + total=self.training_options.n_learn_steps, desc="Training steps" + ) n_learn_steps = self.training_options.n_learn_steps - n_learn_episodes = self.training_options.n_learn_episodes - n_eval_steps = self.training_options.n_eval_steps n_eval_episodes = self.training_options.n_eval_episodes - deterministic_eval = True # TODO: get this value from config - if n_learn_episodes > 0: - self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps) + deterministic_eval = self.training_options.deterministic_eval + self.policy.learn(n_time_steps=n_learn_steps) + self.mode = SessionMode.EVAL if n_eval_episodes > 0: - self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval) + self.eval_progress_bar = progress_bar_manager.counter(total=n_eval_episodes, desc="Evaluation episodes") + self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) + + self.mode = SessionMode.MANUAL def 
step(self): """ @@ -227,12 +251,29 @@ class PrimaiteSession: def advance_timestep(self) -> None: """Advance timestep.""" - self.simulation.apply_timestep(self.step_counter) self.step_counter += 1 + _LOGGER.debug(f"Advancing timestep to {self.step_counter} ") + self.simulation.apply_timestep(self.step_counter) + + if self.training_progress_bar and self.mode == SessionMode.TRAIN: + self.training_progress_bar.update() + + def calculate_truncated(self) -> bool: + """Calculate whether the episode is truncated.""" + current_step = self.step_counter + max_steps = self.training_options.max_steps_per_episode + if current_step >= max_steps: + return True + return False def reset(self) -> None: """Reset the session, this will reset the simulation.""" - return NotImplemented + self.episode_counter += 1 + self.step_counter = 0 + _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") + self.simulation.reset_component_for_episode(self.episode_counter) + if self.eval_progress_bar and self.mode == SessionMode.EVAL: + self.eval_progress_bar.update() def close(self) -> None: """Close the session, this will stop the env and close the simulation.""" From 64e8b3bceaa5ef75208791757d24f01c85b27db7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 15 Nov 2023 16:04:16 +0000 Subject: [PATCH 08/19] Add basic primaite session e2e tests --- .../assets/configs/bad_primaite_session.yaml | 725 +++++++++++++++++ .../configs/eval_only_primaite_session.yaml | 729 ++++++++++++++++++ .../assets/configs/test_primaite_session.yaml | 729 ++++++++++++++++++ .../configs/train_only_primaite_session.yaml | 729 ++++++++++++++++++ tests/conftest.py | 104 +-- .../test_primaite_session.py | 51 ++ 6 files changed, 2981 insertions(+), 86 deletions(-) create mode 100644 tests/assets/configs/bad_primaite_session.yaml create mode 100644 tests/assets/configs/eval_only_primaite_session.yaml create mode 100644 tests/assets/configs/test_primaite_session.yaml create mode 100644 
tests/assets/configs/train_only_primaite_session.yaml create mode 100644 tests/e2e_integration_tests/test_primaite_session.py diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml new file mode 100644 index 00000000..752d98a5 --- /dev/null +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -0,0 +1,725 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + se3ed: 333 + n_learn_steps: 2560 + n_eval_episodes: 5 + + + +game_config: + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + + agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - 
type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - 
logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # FileSystem: # PrimAITE v2 stuff -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -class TempPrimaiteSession: # PrimaiteSession): +class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. Uses context manager for deletion of files upon exit. """ - # def __init__( - # self, - # training_config_path: Union[str, Path], - # lay_down_config_path: Union[str, Path], - # ): - # super().__init__(training_config_path, lay_down_config_path) - # self.setup() + @classmethod + def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession": + """Create a temporary PrimaiteSession object from a config file.""" + config_path = Path(config_path) + with open(config_path, "r") as f: + config = yaml.safe_load(f) - # @property - # def env(self) -> Primaite: - # """Direct access to the env for ease of testing.""" - # return self._agent_session._env # noqa + return super().from_config(cfg=config) - # def __enter__(self): - # return self + def __enter__(self): + return self - # def __exit__(self, type, value, tb): - # shutil.rmtree(self.session_path) - # _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") + def __exit__(self, type, value, tb): + pass -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 @pytest.fixture -def temp_primaite_session(request): - """ - Provides a temporary PrimaiteSession instance. +def temp_primaite_session(request) -> TempPrimaiteSession: + """Create a temporary PrimaiteSession object.""" - It's temporary as it uses a temporary directory as the session path. 
- - To use this fixture you need to: - - - parametrize your test function with: - - - "temp_primaite_session" - - [[path to training config, path to lay down config]] - - Include the temp_primaite_session fixture as a param in your test - function. - - use the temp_primaite_session as a context manager assigning is the - name 'session'. - - .. code:: python - - from primaite.config.lay_down_config import dos_very_basic_config_path - from primaite.config.training_config import main_training_config_path - @pytest.mark.parametrize( - "temp_primaite_session", - [ - [main_training_config_path(), dos_very_basic_config_path()] - ], - indirect=True - ) - def test_primaite_session(temp_primaite_session): - with temp_primaite_session as session: - # Learning outputs are saved in session.learning_path - session.learn() - - # Evaluation outputs are saved in session.evaluation_path - session.evaluate() - - # To ensure that all files are written, you must call .close() - session.close() - - # If you need to inspect any session outputs, it must be done - # inside the context manager - - # Now that we've exited the context manager, the - # session.session_path directory and its contents are deleted - """ - training_config_path = request.param[0] - lay_down_config_path = request.param[1] - with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck: - mck.session_timestamp = datetime.now() - - return TempPrimaiteSession(training_config_path, lay_down_config_path) - - -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -@pytest.fixture -def temp_session_path() -> Path: - """ - Get a temp directory session path the test session will output to. - - :return: The session directory path. 
- """ - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - - return session_path + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py new file mode 100644 index 00000000..5e1da4ff --- /dev/null +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -0,0 +1,51 @@ +import pytest + +from tests.conftest import TempPrimaiteSession + +CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" +TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" +EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" + + +class TestPrimaiteSession: + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_creating_session(self, temp_primaite_session): + """Check that creating a session from config works.""" + with temp_primaite_session as session: + if not isinstance(session, TempPrimaiteSession): + raise AssertionError + + assert session is not None + assert session.simulation + assert len(session.agents) == 3 + assert len(session.rl_agents) == 1 + + assert session.policy + assert session.env + + assert session.simulation.network + assert len(session.simulation.network.nodes) == 10 + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_start_session(self, temp_primaite_session): + """Make sure you can go all the way through the session without errors.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: check that env was closed, that the model was saved, etc. 
+ + @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) + def test_training_only_session(self, temp_primaite_session): + """Check that you can run a training-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? + + @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) + def test_eval_only_session(self, temp_primaite_session): + """Check that you can load a model and run an eval-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was loaded and that the eval-only session ran From 4cc7ba152244ed1d171b192a7d096a0ad20dac9c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 15 Nov 2023 16:59:56 +0000 Subject: [PATCH 09/19] Add ability to save sb3 final model --- src/primaite/game/io.py | 54 ++++++++++++++++++++++++++++++ src/primaite/game/policy/policy.py | 3 +- src/primaite/game/policy/sb3.py | 16 +++++---- src/primaite/game/session.py | 15 +++++++++ 4 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 src/primaite/game/io.py diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py new file mode 100644 index 00000000..76d5ed1c --- /dev/null +++ b/src/primaite/game/io.py @@ -0,0 +1,54 @@ +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel + +from primaite import PRIMAITE_PATHS + + +class SessionIOSettings(BaseModel): + """Schema for session IO settings.""" + + save_final_model: bool = True + """Whether to save the final model right at the end of training.""" + save_checkpoints: bool = False + """Whether to save a checkpoint model every `checkpoint_interval` episodes""" + checkpoint_interval: int = 10 + """How often to save a checkpoint model 
(if save_checkpoints is True).""" + save_logs: bool = True + """Whether to save logs""" + save_transactions: bool = True + """Whether to save transactions, If true, the session path will have a transactions folder.""" + save_tensorboard_logs: bool = False + """Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder.""" + + +class SessionIO: + """ + Class for managing session IO. + + Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + """ + + def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: + self.settings = settings + self.session_path = self.generate_session_path() + + def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: + """Create a folder for the session and return the path to it.""" + if timestamp is None: + timestamp = datetime.now() + date_str = timestamp.strftime("%Y-%m-%d") + time_str = timestamp.strftime("%H-%M-%S") + session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str + session_path.mkdir(exist_ok=True, parents=True) + return session_path + + def generate_model_save_path(self, agent_name: str) -> Path: + """Return the path where the final model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_final" + + def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: + """Return the path where the checkpoint model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 4c8dc447..6a2381c1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -1,5 +1,6 @@ """Base class and common logic for RL policies.""" from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, 
TYPE_CHECKING if TYPE_CHECKING: @@ -54,7 +55,7 @@ class PolicyABC(ABC): pass @abstractmethod - def save(self) -> None: + def save(self, save_path: Path) -> None: """Save the agent.""" pass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index ff710944..1be4f915 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,4 +1,5 @@ """Stable baselines 3 policy.""" +from pathlib import Path from typing import Literal, Optional, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO @@ -50,12 +51,15 @@ class SB3Policy(PolicyABC, identifier="SB3"): ) print(reward_data) - def save(self) -> None: - """Save the agent.""" - savepath = ( - "temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session - ) - self._agent.save(savepath) + def save(self, save_path: Path) -> None: + """ + Save the current policy parameters. + + Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please + refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information. + Therefore, this method is only used to save the final model. 
+ """ + self._agent.save(save_path) pass def load(self) -> None: diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index e85328ef..37c34da9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -13,6 +13,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.io import SessionIO, SessionIOSettings from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer @@ -179,6 +180,10 @@ class PrimaiteSession: """evaluation episodes counter""" self.mode: SessionMode = SessionMode.MANUAL + """Current session mode.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" def start_session(self) -> None: """Commence the training session.""" @@ -190,6 +195,7 @@ class PrimaiteSession: n_eval_episodes = self.training_options.n_eval_episodes deterministic_eval = self.training_options.deterministic_eval self.policy.learn(n_time_steps=n_learn_steps) + self.save_models() self.mode = SessionMode.EVAL if n_eval_episodes > 0: @@ -198,6 +204,11 @@ class PrimaiteSession: self.mode = SessionMode.MANUAL + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) + def step(self): """ Perform one step of the simulation/agent loop. 
@@ -500,4 +511,8 @@ class PrimaiteSession: # CREATE POLICY sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + # READ IO SETTINGS + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIO(settings=SessionIOSettings(**io_settings)) + return sess From 829500a60f70da63a0c9b2eba30d24f45c523393 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 14:37:37 +0000 Subject: [PATCH 10/19] Get sb3 checkpoints saving during training --- .../config/_package_data/example_config.yaml | 6 ++++- src/primaite/game/io.py | 4 ++-- src/primaite/game/policy/policy.py | 4 ++-- src/primaite/game/policy/sb3.py | 19 +++++++++++----- src/primaite/game/session.py | 22 ++++++++++++------- 5 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index dca9620f..e0ff9276 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 2560 + n_learn_episodes: 25 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false @@ -10,6 +10,10 @@ training_config: agent_references: - defender +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + game_config: ports: diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py index 76d5ed1c..e613316d 100644 --- a/src/primaite/game/io.py +++ b/src/primaite/game/io.py @@ -32,8 +32,8 @@ class SessionIO: """ def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: - self.settings = settings - self.session_path = self.generate_session_path() + self.settings: SessionIOSettings = settings + self.session_path: Path = self.generate_session_path() def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" 
diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 6a2381c1..a7052367 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -45,12 +45,12 @@ class PolicyABC(ABC): """Reference to the session.""" @abstractmethod - def learn(self, n_episodes: int, n_time_steps: int) -> None: + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" pass @abstractmethod - def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: + def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None: """Evaluate the agent.""" pass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 1be4f915..10f22e05 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -4,6 +4,7 @@ from typing import Literal, Optional, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP +from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy as PPO_MLP @@ -36,9 +37,17 @@ class SB3Policy(PolicyABC, identifier="SB3"): seed=seed, ) - def learn(self, n_time_steps: int) -> None: + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - self._agent.learn(total_timesteps=n_time_steps) + if self.session.io_manager.settings.save_checkpoints: + checkpoint_callback = CheckpointCallback( + save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval, + save_path=self.session.io_manager.generate_model_save_path("sb3"), + name_prefix="sb3_model", + ) + else: + checkpoint_callback = None + self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback) def eval(self, n_episodes: int, deterministic: bool) -> None: """Evaluate the 
agent.""" @@ -60,12 +69,10 @@ class SB3Policy(PolicyABC, identifier="SB3"): Therefore, this method is only used to save the final model. """ self._agent.save(save_path) - pass - def load(self) -> None: + def load(self, model_path: Path) -> None: """Load agent from a checkpoint.""" - self._agent_class.load("temp/path/to/save.pth", env=self.session.env) - pass + self._agent = self._agent_class.load(model_path, env=self.session.env) def close(self) -> None: """Close the agent.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 37c34da9..a2e83cbb 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -112,11 +112,12 @@ class TrainingOptions(BaseModel): rl_framework: Literal["SB3", "RLLIB"] rl_algorithm: Literal["PPO", "A2C"] - seed: Optional[int] - n_learn_steps: int + n_learn_episodes: int n_eval_episodes: Optional[int] = None max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None deterministic_eval: bool + seed: Optional[int] n_agents: int agent_references: List[str] @@ -188,13 +189,18 @@ class PrimaiteSession: def start_session(self) -> None: """Commence the training session.""" self.mode = SessionMode.TRAIN - self.training_progress_bar = progress_bar_manager.counter( - total=self.training_options.n_learn_steps, desc="Training steps" - ) - n_learn_steps = self.training_options.n_learn_steps + n_learn_episodes = self.training_options.n_learn_episodes n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + self.training_progress_bar = progress_bar_manager.counter( + total=n_learn_episodes * max_steps_per_episode, desc="Training steps" + ) + deterministic_eval = self.training_options.deterministic_eval - self.policy.learn(n_time_steps=n_learn_steps) + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) self.save_models() self.mode = SessionMode.EVAL @@ -513,6 +519,6 @@ class 
PrimaiteSession: # READ IO SETTINGS io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIO(settings=SessionIOSettings(**io_settings)) + sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings)) return sess From 7545c25a467c5768c3516e774c702a659c172f48 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 15:11:03 +0000 Subject: [PATCH 11/19] Make pytest patch with temporary session dir --- src/primaite/__init__.py | 36 +++++++++---------- .../configs/eval_only_primaite_session.yaml | 2 +- .../assets/configs/test_primaite_session.yaml | 2 +- .../configs/train_only_primaite_session.yaml | 2 +- tests/conftest.py | 14 ++++---- tests/mock_and_patch/get_session_path_mock.py | 2 +- 6 files changed, 30 insertions(+), 28 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 789517f7..28245d33 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -29,6 +29,15 @@ class _PrimaitePaths: def __init__(self) -> None: self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) + self.user_home_path = self.generate_user_home_path() + self.user_sessions_path = self.generate_user_sessions_path() + self.user_config_path = self.generate_user_config_path() + self.user_notebooks_path = self.generate_user_notebooks_path() + self.app_home_path = self.generate_app_home_path() + self.app_config_dir_path = self.generate_app_config_dir_path() + self.app_config_file_path = self.generate_app_config_file_path() + self.app_log_dir_path = self.generate_app_log_dir_path() + self.app_log_file_path = self.generate_app_log_file_path() def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() @@ -43,55 +52,47 @@ class _PrimaitePaths: for p in self._get_dirs_properties(): getattr(self, p) - @property - def user_home_path(self) -> Path: + def generate_user_home_path(self) -> Path: """The PrimAITE user home path.""" path = Path.home() / "primaite" / 
__version__ path.mkdir(exist_ok=True, parents=True) return path - @property - def user_sessions_path(self) -> Path: + def generate_user_sessions_path(self) -> Path: """The PrimAITE user sessions path.""" path = self.user_home_path / "sessions" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_config_path(self) -> Path: + def generate_user_config_path(self) -> Path: """The PrimAITE user config path.""" path = self.user_home_path / "config" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_notebooks_path(self) -> Path: + def generate_user_notebooks_path(self) -> Path: """The PrimAITE user notebooks path.""" path = self.user_home_path / "notebooks" path.mkdir(exist_ok=True, parents=True) return path - @property - def app_home_path(self) -> Path: + def generate_app_home_path(self) -> Path: """The PrimAITE app home path.""" path = self._dirs.user_data_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_dir_path(self) -> Path: + def generate_app_config_dir_path(self) -> Path: """The PrimAITE app config directory path.""" path = self._dirs.user_config_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_file_path(self) -> Path: + def generate_app_config_file_path(self) -> Path: """The PrimAITE app config file path.""" return self.app_config_dir_path / "primaite_config.yaml" - @property - def app_log_dir_path(self) -> Path: + def generate_app_log_dir_path(self) -> Path: """The PrimAITE app log directory path.""" if sys.platform == "win32": path = self.app_home_path / "logs" @@ -100,8 +101,7 @@ class _PrimaitePaths: path.mkdir(exist_ok=True, parents=True) return path - @property - def app_log_file_path(self) -> Path: + def generate_app_log_file_path(self) -> Path: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" diff --git a/tests/assets/configs/eval_only_primaite_session.yaml 
b/tests/assets/configs/eval_only_primaite_session.yaml index 2ab7a2cc..1c9104d1 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 0 + n_learn_episodes: 0 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index dca9620f..201528eb 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 2560 + n_learn_episodes: 10 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5f0cfc77..1ed10212 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 2560 + n_learn_episodes: 10 n_eval_episodes: 0 max_steps_per_episode: 128 deterministic_eval: false diff --git a/tests/conftest.py b/tests/conftest.py index 60b69a1e..fe450213 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import tempfile from datetime import datetime from pathlib import Path from typing import Any, Dict, Union -from unittest.mock import patch import nodeenv import pytest @@ -22,13 +21,15 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service -from tests.mock_and_patch.get_session_path_mock import get_temp_session_path 
+from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) +from primaite import PRIMAITE_PATHS + # PrimAITE v3 stuff from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.base import Node @@ -97,8 +98,9 @@ class TempPrimaiteSession(PrimaiteSession): @pytest.fixture -def temp_primaite_session(request) -> TempPrimaiteSession: +def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: """Create a temporary PrimaiteSession object.""" - - config_path = request.param[0] - return TempPrimaiteSession.from_config(config_path=config_path) + with monkeypatch.context() as m: + m.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index 16c4a274..06fe5893 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -9,7 +9,7 @@ from primaite import getLogger _LOGGER = getLogger(__name__) -def get_temp_session_path(session_timestamp: datetime) -> Path: +def temp_user_sessions_path() -> Path: """ Get a temp directory session path the test session will output to. 
From 13c49bf3eaeca21c737f68b745351d59b1098f8e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 15:19:14 +0000 Subject: [PATCH 12/19] Fix session path monkeypatch --- tests/conftest.py | 7 +++---- tests/e2e_integration_tests/test_primaite_session.py | 2 ++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fe450213..6a65b12f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,6 @@ class TempPrimaiteSession(PrimaiteSession): @pytest.fixture def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: """Create a temporary PrimaiteSession object.""" - with monkeypatch.context() as m: - m.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) - config_path = request.param[0] - return TempPrimaiteSession.from_config(config_path=config_path) + monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 5e1da4ff..c6179e9a 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -40,6 +40,8 @@ class TestPrimaiteSession: with temp_primaite_session as session: session: TempPrimaiteSession session.start_session() + for i in range(100): + print(session.io_manager.generate_session_path()) # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? 
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) From 0b9bdedebd719849038ed6736598d8bd09b321ff Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 15:28:38 +0000 Subject: [PATCH 13/19] Fix typehints --- src/primaite/game/agent/rewards.py | 4 ++-- src/primaite/game/policy/policy.py | 4 ++-- src/primaite/game/policy/sb3.py | 4 ++-- src/primaite/game/session.py | 4 ++-- tests/e2e_integration_tests/test_primaite_session.py | 2 -- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 49d56e67..da1331b0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,7 +26,7 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, Type, TYPE_CHECKING from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward): class RewardFunction: """Manages the reward function for the agent.""" - __rew_class_identifiers: Dict[str, type[AbstractReward]] = { + __rew_class_identifiers: Dict[str, Type[AbstractReward]] = { "DUMMY": DummyReward, "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index a7052367..249c3b52 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -1,7 +1,7 @@ """Base class and common logic for RL policies.""" from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: from primaite.game.session import PrimaiteSession, TrainingOptions @@ -10,7 +10,7 @@ if TYPE_CHECKING: class PolicyABC(ABC): """Base class for 
reinforcement learning agents.""" - _registry: Dict[str, type["PolicyABC"]] = {} + _registry: Dict[str, Type["PolicyABC"]] = {} """ Registry of policy types, keyed by name. diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 10f22e05..bb35775a 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,6 +1,6 @@ """Stable baselines 3 policy.""" from pathlib import Path -from typing import Literal, Optional, TYPE_CHECKING, Union +from typing import Literal, Optional, Type, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP @@ -21,7 +21,7 @@ class SB3Policy(PolicyABC, identifier="SB3"): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) - self._agent_class: type[Union[PPO, A2C]] + self._agent_class: Type[Union[PPO, A2C]] if algorithm == "PPO": self._agent_class = PPO policy = PPO_MLP diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a2e83cbb..88c1e061 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -52,7 +52,7 @@ class PrimaiteGymEnv(gymnasium.Env): self.session: "PrimaiteSession" = session self.agent: ProxyAgent = agents[0] - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) @@ -70,7 +70,7 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info - def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]: + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" self.session.reset() state = self.session.get_sim_state() diff --git 
a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index c6179e9a..5e1da4ff 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -40,8 +40,6 @@ class TestPrimaiteSession: with temp_primaite_session as session: session: TempPrimaiteSession session.start_session() - for i in range(100): - print(session.io_manager.generate_session_path()) # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) From e52d1fbd4500195a6bd0df9935a24ca3ec0291f5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 15:29:48 +0000 Subject: [PATCH 14/19] Add enlighten dependency --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 51ed84f2..1e074c25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ dependencies = [ "stable-baselines3[extra]==2.1.0", "tensorflow==2.12.0", "typer[all]==0.9.0", - "pydantic==2.1.1" + "pydantic==2.1.1", + "enlighten==1.12.2" ] [tool.setuptools.dynamic] From 0861663cc1617cc38f056d8f64a64d4ac4313ca5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 15:40:49 +0000 Subject: [PATCH 15/19] Add agent loading --- src/primaite/cli.py | 3 ++- src/primaite/game/session.py | 5 ++++- src/primaite/main.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 0f17525e..81ab2792 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -119,6 +119,7 @@ def setup(overwrite_existing: bool = True) -> None: @app.command() def session( config: Optional[str] = None, + agent_load_file: Optional[str] = None, ) -> None: """ Run a PrimAITE session. 
@@ -132,4 +133,4 @@ def session( if not config: config = example_config_path() print(config) - run(config_path=config) + run(config_path=config, agent_load_path=agent_load_file) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 88c1e061..f265b7d9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,6 +1,7 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" from enum import Enum from ipaddress import IPv4Address +from pathlib import Path from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple import enlighten @@ -297,7 +298,7 @@ class PrimaiteSession: return NotImplemented @classmethod - def from_config(cls, cfg: dict) -> "PrimaiteSession": + def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary. The config dictionary should have the following top-level keys: @@ -516,6 +517,8 @@ class PrimaiteSession: # CREATE POLICY sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) # READ IO SETTINGS io_settings = cfg.get("io_settings", {}) diff --git a/src/primaite/main.py b/src/primaite/main.py index 831419d4..1699fe51 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__) def run( config_path: Optional[Union[str, Path]] = "", + agent_load_path: Optional[Union[str, Path]] = None, ) -> None: """ Run the PrimAITE Session. @@ -31,7 +32,7 @@ def run( otherwise False. 
""" cfg = load(config_path) - sess = PrimaiteSession.from_config(cfg=cfg) + sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() From ba580b00b41324343c31e1ccc4f4767ffc8c26a2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 16:14:50 +0000 Subject: [PATCH 16/19] Improve config validation and fix tests --- src/primaite/game/io.py | 6 +++++- src/primaite/game/session.py | 8 ++++++-- tests/assets/configs/test_primaite_session.yaml | 4 ++++ .../e2e_integration_tests/test_primaite_session.py | 13 ++++++++++++- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py index e613316d..d510d108 100644 --- a/src/primaite/game/io.py +++ b/src/primaite/game/io.py @@ -2,7 +2,7 @@ from datetime import datetime from pathlib import Path from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from primaite import PRIMAITE_PATHS @@ -10,6 +10,8 @@ from primaite import PRIMAITE_PATHS class SessionIOSettings(BaseModel): """Schema for session IO settings.""" + model_config = ConfigDict(extra="forbid") + save_final_model: bool = True """Whether to save the final model right at the end of training.""" save_checkpoints: bool = False @@ -34,6 +36,8 @@ class SessionIO: def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: self.settings: SessionIOSettings = settings self.session_path: Path = self.generate_session_path() + # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's + # possible refactor needed def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index f265b7d9..655e2459 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -7,7 +7,7 @@ from typing 
import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple import enlighten import gymnasium from gymnasium.core import ActType, ObsType -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager @@ -104,6 +104,8 @@ class PrimaiteSessionOptions(BaseModel): Currently this is used to restrict which ports and protocols exist in the world of the simulation. """ + model_config = ConfigDict(extra="forbid") + ports: List[str] protocols: List[str] @@ -111,6 +113,8 @@ class PrimaiteSessionOptions(BaseModel): class TrainingOptions(BaseModel): """Options for training the RL agent.""" + model_config = ConfigDict(extra="forbid") + rl_framework: Literal["SB3", "RLLIB"] rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int @@ -522,6 +526,6 @@ class PrimaiteSession: # READ IO SETTINGS io_settings = cfg.get("io_settings", {}) - sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings)) + sess.io_manager.settings = SessionIOSettings(**io_settings) return sess diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 201528eb..9445cd2b 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -10,6 +10,10 @@ training_config: agent_references: - defender +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + game_config: ports: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 5e1da4ff..3ef5b6da 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -32,7 +32,18 @@ class TestPrimaiteSession: with temp_primaite_session as session: session: TempPrimaiteSession session.start_session() - # TODO: check that env was closed, that the model was saved, etc. 
+ + session_path = session.io_manager.session_path + assert session_path.exists() + print(list(session_path.glob("*"))) + checkpoint_dir = session_path / "checkpoints" / "sb3_final" + assert checkpoint_dir.exists() + checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip" + checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip" + checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip" + assert checkpoint_1.exists() + assert checkpoint_2.exists() + assert not checkpoint_3.exists() @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) def test_training_only_session(self, temp_primaite_session): From 5bda952ead90e566593681eefdfa9d223c84af3a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 10:20:26 +0000 Subject: [PATCH 17/19] Fix sim output --- src/primaite/game/io.py | 5 ++++ src/primaite/game/session.py | 9 +++--- src/primaite/simulator/__init__.py | 30 +++++++++++++------ .../simulator/network/hardware/base.py | 2 +- .../simulator/system/core/packet_capture.py | 2 +- src/primaite/simulator/system/core/sys_log.py | 3 +- 6 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py index d510d108..e0b849c9 100644 --- a/src/primaite/game/io.py +++ b/src/primaite/game/io.py @@ -5,6 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from primaite import PRIMAITE_PATHS +from primaite.simulator import SIM_OUTPUT class SessionIOSettings(BaseModel): @@ -36,6 +37,10 @@ class SessionIO: def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: self.settings: SessionIOSettings = settings self.session_path: Path = self.generate_session_path() + + # set global SIM_OUTPUT path + SIM_OUTPUT.path = self.session_path / "simulation_output" + # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's # possible refactor needed diff --git a/src/primaite/game/session.py 
b/src/primaite/game/session.py index 655e2459..a2c04980 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -324,6 +324,11 @@ class PrimaiteSession: protocols=cfg["game_config"]["protocols"], ) sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) + sim = sess.simulation net = sim.network @@ -524,8 +529,4 @@ class PrimaiteSession: if agent_load_path: sess.policy.load(Path(agent_load_path)) - # READ IO SETTINGS - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - return sess diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 8c55542f..19c86e28 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -1,14 +1,26 @@ +"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory.""" from datetime import datetime +from pathlib import Path from primaite import _PRIMAITE_ROOT -SIM_OUTPUT = None -"A path at the repo root dir to use temporarily for sim output testing while in dev." 
-# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path +__all__ = ["SIM_OUTPUT"] -if not SIM_OUTPUT: - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path - SIM_OUTPUT.mkdir(exist_ok=True, parents=True) + +class __SimOutput: + def __init__(self): + self._path: Path = ( + _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ) + + @property + def path(self) -> Path: + return self._path + + @path.setter + def path(self, new_path: Path) -> None: + self._path = new_path + self._path.mkdir(exist_ok=True, parents=True) + + +SIM_OUTPUT = __SimOutput() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 537cebb2..29d3a05c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -957,7 +957,7 @@ class Node(SimComponent): if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index 2e5ed008..c2faeb10 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -75,7 +75,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT / self.hostname + root = 
SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 791e0be8..7ac6df85 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -41,7 +41,6 @@ class SysLog: JSON-like messages. """ log_path = self._get_log_path() - file_handler = logging.FileHandler(filename=log_path) file_handler.setLevel(logging.DEBUG) @@ -81,7 +80,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" From c5b4ae45be8c13162f98113dc52553d40cfb4668 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 11:40:36 +0000 Subject: [PATCH 18/19] Remove problematic progress bars --- pyproject.toml | 1 - src/primaite/game/session.py | 18 ------------------ 2 files changed, 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e074c25..92f78ec0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ dependencies = [ "tensorflow==2.12.0", "typer[all]==0.9.0", "pydantic==2.1.1", - "enlighten==1.12.2" ] [tool.setuptools.dynamic] diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a2c04980..ad0537e8 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -4,7 +4,6 @@ from ipaddress import IPv4Address from pathlib import Path from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple -import enlighten import gymnasium from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict @@ -34,8 +33,6 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor from primaite.simulator.system.services.service import Service from 
primaite.simulator.system.services.web_server.web_server import WebServer -progress_bar_manager = enlighten.get_manager() - _LOGGER = getLogger(__name__) @@ -179,12 +176,6 @@ class PrimaiteSession: self.env: PrimaiteGymEnv """The environment that the agent can consume. Could be PrimaiteEnv.""" - self.training_progress_bar: Optional[enlighten.Counter] = None - """training steps counter""" - - self.eval_progress_bar: Optional[enlighten.Counter] = None - """evaluation episodes counter""" - self.mode: SessionMode = SessionMode.MANUAL """Current session mode.""" @@ -197,9 +188,6 @@ class PrimaiteSession: n_learn_episodes = self.training_options.n_learn_episodes n_eval_episodes = self.training_options.n_eval_episodes max_steps_per_episode = self.training_options.max_steps_per_episode - self.training_progress_bar = progress_bar_manager.counter( - total=n_learn_episodes * max_steps_per_episode, desc="Training steps" - ) deterministic_eval = self.training_options.deterministic_eval self.policy.learn( @@ -210,7 +198,6 @@ class PrimaiteSession: self.mode = SessionMode.EVAL if n_eval_episodes > 0: - self.eval_progress_bar = progress_bar_manager.counter(total=n_eval_episodes, desc="Evaluation episodes") self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) self.mode = SessionMode.MANUAL @@ -277,9 +264,6 @@ class PrimaiteSession: _LOGGER.debug(f"Advancing timestep to {self.step_counter} ") self.simulation.apply_timestep(self.step_counter) - if self.training_progress_bar and self.mode == SessionMode.TRAIN: - self.training_progress_bar.update() - def calculate_truncated(self) -> bool: """Calculate whether the episode is truncated.""" current_step = self.step_counter @@ -294,8 +278,6 @@ class PrimaiteSession: self.step_counter = 0 _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") self.simulation.reset_component_for_episode(self.episode_counter) - if self.eval_progress_bar and self.mode == SessionMode.EVAL: - 
self.eval_progress_bar.update() def close(self) -> None: """Close the session, this will stop the env and close the simulation.""" From 9d0a98b22122e8f8b4c000ad84d613acf45252f8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 20:30:07 +0000 Subject: [PATCH 19/19] Apply suggestions from code review --- src/primaite/game/policy/sb3.py | 4 ---- tests/assets/configs/bad_primaite_session.yaml | 2 +- tests/e2e_integration_tests/test_primaite_session.py | 6 ++++++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index bb35775a..a4870054 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -74,10 +74,6 @@ class SB3Policy(PolicyABC, identifier="SB3"): """Load agent from a checkpoint.""" self._agent = self._agent_class.load(model_path, env=self.session.env) - def close(self) -> None: - """Close the agent.""" - pass - @classmethod def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": """Create an agent from config file.""" diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 752d98a5..80567aea 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -1,7 +1,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO - se3ed: 333 + se3ed: 333 # Purposeful typo to check that error is raised with bad configuration. 
n_learn_steps: 2560 n_eval_episodes: 5 diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 3ef5b6da..b6122bad 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -1,3 +1,4 @@ +import pydantic import pytest from tests.conftest import TempPrimaiteSession @@ -5,6 +6,7 @@ from tests.conftest import TempPrimaiteSession CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" +MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml" class TestPrimaiteSession: @@ -60,3 +62,7 @@ class TestPrimaiteSession: session: TempPrimaiteSession session.start_session() # TODO: include checks that the model was loaded and that the eval-only session ran + + def test_error_thrown_on_bad_configuration(self): + with pytest.raises(pydantic.ValidationError): + session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)