diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 066c66b2..0bb03594 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -86,5 +86,5 @@ stages: displayName: 'Perform PrimAITE Setup' - script: | - pytest tests/ + pytest -n 4 displayName: 'Run tests' diff --git a/.gitignore b/.gitignore index 5d6434f1..ef1050e6 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,9 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +tests/assets/**/*.png +tests/assets/**/tensorboard_logs/ +tests/assets/**/checkpoints/ # Translations *.mo diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index bfc5ee16..bfb66332 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -43,6 +43,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions//) + +When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory Outputs ------- diff --git a/pyproject.toml b/pyproject.toml index dc04f609..fc0551c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ "pip-licenses==4.3.0", "pre-commit==2.20.0", "pytest==7.2.0", + "pytest-xdist==3.3.1", "pytest-cov==4.0.0", "pytest-flake8==1.1.1", "setuptools==66", diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent_abc.py similarity index 62% rename from src/primaite/agents/agent.py rename to src/primaite/agents/agent_abc.py index 1f06a371..515adfd0 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent_abc.py @@ -1,21 +1,19 @@ from __future__ import annotations import json -import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Dict, Final, Union +from typing import Dict, Optional, Union from uuid import uuid4 -import yaml - import primaite from primaite import getLogger, SESSIONS_DIR from primaite.config import lay_down_config, training_config from 
primaite.config.training_config import TrainingConfig from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite +from primaite.utils.session_metadata_parser import parse_session_metadata _LOGGER = getLogger(__name__) @@ -47,38 +45,63 @@ class AgentSessionABC(ABC): """ @abstractmethod - def __init__(self, training_config_path, lay_down_config_path): + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = None, + lay_down_config_path: Optional[Union[str, Path]] = None, + session_path: Optional[Union[str, Path]] = None, + ): """ - Initialise an agent session from config files. + Initialise an agent session from config files, or load a previous session. + + If training configuration and laydown configuration are provided with a session path, + the session path will be used. :param training_config_path: YAML file containing configurable items defined in `primaite.config.training_config.TrainingConfig` :type training_config_path: Union[path, str] :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
:type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ - if not isinstance(training_config_path, Path): - training_config_path = Path(training_config_path) - self._training_config_path: Final[Union[Path, str]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) - - if not isinstance(lay_down_config_path, Path): - lay_down_config_path = Path(lay_down_config_path) - self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - + # initialise variables self._env: Primaite self._agent = None self._can_learn: bool = False self._can_evaluate: bool = False self.is_eval = False - self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() - "The session timestamp" - self.session_path = get_session_path(self.session_timestamp) - "The Session path" + + # convert session to path + if session_path is not None: + if not isinstance(session_path, Path): + session_path = Path(session_path) + + # if a session path is provided, load it + if not session_path.exists(): + raise Exception(f"Session could not be loaded. 
Path does not exist: {session_path}") + + # load session + self.load(session_path) + else: + # set training config path + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Union[Path, str] = training_config_path + self._training_config: TrainingConfig = training_config.load(self._training_config_path) + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Union[Path, str] = lay_down_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level + + # set random UUID for session + self._uuid = str(uuid4()) + "The session timestamp" + self.session_path = get_session_path(self.session_timestamp) + "The Session path" @property def timestamp_str(self) -> str: @@ -227,52 +250,27 @@ class AgentSessionABC(ABC): def _get_latest_checkpoint(self): pass - @classmethod - @abstractmethod - def load(cls, path: Union[str, Path]) -> AgentSessionABC: + def load(self, path: Union[str, Path]): """Load an agent from file.""" - if not isinstance(path, Path): - path = Path(path) + md_dict, training_config_path, laydown_config_path = parse_session_metadata(path) - if path.exists(): - # Unpack the session_metadata.json file - md_file = path / "session_metadata.json" - with open(md_file, "r") as file: - md_dict = json.load(file) + # set training config path + self._training_config_path: Union[Path, str] = training_config_path + self._training_config: TrainingConfig = training_config.load(self._training_config_path) + self._lay_down_config_path: Union[Path, str] = laydown_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - # Create a temp directory and dump the training and lay down - # configs into it - 
temp_dir = path / ".temp" - temp_dir.mkdir(exist_ok=True) + # set random UUID for session + self._uuid = md_dict["uuid"] - temp_tc = temp_dir / "tc.yaml" - with open(temp_tc, "w") as file: - yaml.dump(md_dict["env"]["training_config"], file) - - temp_ldc = temp_dir / "ldc.yaml" - with open(temp_ldc, "w") as file: - yaml.dump(md_dict["env"]["lay_down_config"], file) - - agent = cls(temp_tc, temp_ldc) - - agent.session_path = path - - return agent - - else: - # Session path does not exist - msg = f"Failed to load PrimAITE Session, path does not exist: {path}" - _LOGGER.error(msg) - raise FileNotFoundError(msg) - pass + # set the session path + self.session_path = path + "The Session path" @property def _saved_agent_path(self) -> Path: - file_name = ( - f"{self._training_config.agent_framework}_" - f"{self._training_config.agent_identifier}_" - f"{self.timestamp_str}.zip" - ) + file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip" return self.learning_path / file_name @abstractmethod @@ -308,104 +306,3 @@ class AgentSessionABC(ABC): fig = plot_av_reward_per_episode(path, title, subtitle) fig.write_image(image_path) _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") - - -class HardCodedAgentSessionABC(AgentSessionABC): - """ - An Agent Session ABC for evaluation deterministic agents. - - This class cannot be directly instantiated and must be inherited from with all implemented abstract methods - implemented. - """ - - def __init__(self, training_config_path, lay_down_config_path): - """ - Initialise a hardcoded agent session. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
- :type lay_down_config_path: Union[path, str] - """ - super().__init__(training_config_path, lay_down_config_path) - self._setup() - - def _setup(self): - self._env: Primaite = Primaite( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ) - super()._setup() - self._can_learn = False - self._can_evaluate = True - - def _save_checkpoint(self): - pass - - def _get_latest_checkpoint(self): - pass - - def learn( - self, - **kwargs, - ): - """ - Train the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - _LOGGER.warning("Deterministic agents cannot learn") - - @abstractmethod - def _calculate_action(self, obs): - pass - - def evaluate( - self, - **kwargs, - ): - """ - Evaluate the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - self._env.set_as_eval() # noqa - self.is_eval = True - - time_steps = self._training_config.num_eval_steps - episodes = self._training_config.num_eval_episodes - - obs = self._env.reset() - for episode in range(episodes): - # Reset env and collect initial observation - for step in range(time_steps): - # Calculate action - action = self._calculate_action(obs) - - # Perform the step - obs, reward, done, info = self._env.step(action) - - if done: - break - - # Introduce a delay between steps - time.sleep(self._training_config.time_delay / 1000) - obs = self._env.reset() - self._env.close() - super().evaluate() - - @classmethod - def load(cls): - """Load an agent from file.""" - _LOGGER.warning("Deterministic agents cannot be loaded") - - def save(self): - """Save the agent.""" - _LOGGER.warning("Deterministic agents cannot be saved") - - def export(self): - """Export the agent to transportable file format.""" - _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py new file 
mode 100644 index 00000000..cfee3e16 --- /dev/null +++ b/src/primaite/agents/hardcoded_abc.py @@ -0,0 +1,115 @@ +import time +from abc import abstractmethod +from pathlib import Path +from typing import Optional, Union + +from primaite import getLogger +from primaite.agents.agent_abc import AgentSessionABC +from primaite.environment.primaite_env import Primaite + +_LOGGER = getLogger(__name__) + + +class HardCodedAgentSessionABC(AgentSessionABC): + """ + An Agent Session ABC for evaluation deterministic agents. + + This class cannot be directly instantiated and must be inherited from with all implemented abstract methods + implemented. + """ + + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, + ): + """ + Initialise a hardcoded agent session. + + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + """ + super().__init__(training_config_path, lay_down_config_path, session_path) + self._setup() + + def _setup(self): + self._env: Primaite = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ) + super()._setup() + self._can_learn = False + self._can_evaluate = True + + def _save_checkpoint(self): + pass + + def _get_latest_checkpoint(self): + pass + + def learn( + self, + **kwargs, + ): + """ + Train the agent. + + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ + _LOGGER.warning("Deterministic agents cannot learn") + + @abstractmethod + def _calculate_action(self, obs): + pass + + def evaluate( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-specific key-word args to be passed. + """ + self._env.set_as_eval() # noqa + self.is_eval = True + + time_steps = self._training_config.num_eval_steps + episodes = self._training_config.num_eval_episodes + + obs = self._env.reset() + for episode in range(episodes): + # Reset env and collect initial observation + for step in range(time_steps): + # Calculate action + action = self._calculate_action(obs) + + # Perform the step + obs, reward, done, info = self._env.step(action) + + if done: + break + + # Introduce a delay between steps + time.sleep(self._training_config.time_delay / 1000) + obs = self._env.reset() + self._env.close() + + @classmethod + def load(cls, path=None): + """Load an agent from file.""" + _LOGGER.warning("Deterministic agents cannot be loaded") + + def save(self): + """Save the agent.""" + _LOGGER.warning("Deterministic agents cannot be saved") + + def export(self): + """Export the agent to transportable file format.""" + _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 166ff415..e08a1d6d 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -4,7 +4,7 @@ import numpy as np from primaite.acl.access_control_list import AccessControlList from primaite.acl.acl_rule import ACLRule -from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, get_node_of_ip, diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index c00cf421..113f622a 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ 
-1,6 +1,6 @@ import numpy as np -from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 6253f574..1707cb81 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -4,7 +4,7 @@ import json import shutil from datetime import datetime from pathlib import Path -from typing import Union +from typing import Optional, Union from uuid import uuid4 from ray.rllib.algorithms import Algorithm @@ -14,7 +14,7 @@ from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env from primaite import getLogger -from primaite.agents.agent import AgentSessionABC +from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite @@ -43,7 +43,12 @@ def _custom_log_creator(session_path: Path): class RLlibAgent(AgentSessionABC): """An AgentSession class that implements a Ray RLlib agent.""" - def __init__(self, training_config_path, lay_down_config_path): + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, + ): """ Initialise the RLLib Agent training session. 
@@ -56,6 +61,13 @@ class RLlibAgent(AgentSessionABC): :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` or `A2C`) """ + # TODO: implement RLlib agent loading + if session_path is not None: + msg = "RLlib agent loading has not been implemented yet" + _LOGGER.error(msg) + print(msg) + raise NotImplementedError + super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index cb00985a..862a0116 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,14 +1,15 @@ from __future__ import annotations +import json from pathlib import Path -from typing import Union +from typing import Optional, Union import numpy as np from stable_baselines3 import A2C, PPO from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import getLogger -from primaite.agents.agent import AgentSessionABC +from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite @@ -18,7 +19,12 @@ _LOGGER = getLogger(__name__) class SB3Agent(AgentSessionABC): """An AgentSession class that implements a Stable Baselines3 agent.""" - def __init__(self, training_config_path, lay_down_config_path): + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = None, + lay_down_config_path: Optional[Union[str, Path]] = None, + session_path: Optional[Union[str, Path]] = None, + ): """ Initialise the SB3 Agent training session. 
@@ -31,7 +37,7 @@ class SB3Agent(AgentSessionABC): :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` or `A2C`) """ - super().__init__(training_config_path, lay_down_config_path) + super().__init__(training_config_path, lay_down_config_path, session_path) if not self._training_config.agent_framework == AgentFramework.SB3: msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) @@ -47,7 +53,7 @@ class SB3Agent(AgentSessionABC): self._tensorboard_log_path = self.learning_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) - self._setup() + _LOGGER.debug( f"Created {self.__class__.__name__} using: " f"agent_framework={self._training_config.agent_framework}, " @@ -57,8 +63,10 @@ class SB3Agent(AgentSessionABC): self.is_eval = False + self._setup() + def _setup(self): - super()._setup() + """Set up the SB3 Agent.""" self._env = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, @@ -66,14 +74,43 @@ class SB3Agent(AgentSessionABC): timestamp_str=self.timestamp_str, ) - self._agent = self._agent_class( - PPOMlp, - self._env, - verbose=self.sb3_output_verbose_level, - n_steps=self._training_config.num_train_steps, - tensorboard_log=str(self._tensorboard_log_path), - seed=self._training_config.seed, - ) + # check if there is a zip file that needs to be loaded + load_file = next(self.session_path.rglob("*.zip"), None) + + if not load_file: + # create a new env and agent + + self._agent = self._agent_class( + PPOMlp, + self._env, + verbose=self.sb3_output_verbose_level, + n_steps=self._training_config.num_train_steps, + tensorboard_log=str(self._tensorboard_log_path), + seed=self._training_config.seed, + ) + else: + # set env values from session metadata + with open(self.session_path / "session_metadata.json", "r") as file: + md_dict = json.load(file) + + # load environment 
values + if self.is_eval: + # evaluation always starts at 0 + self._env.episode_count = 0 + self._env.total_step_count = 0 + else: + # carry on from previous learning sessions + self._env.episode_count = md_dict["learning"]["total_episodes"] + self._env.total_step_count = md_dict["learning"]["total_time_steps"] + + # load the file + self._agent = self._agent_class.load(load_file, env=self._env) + + # set agent values + self._agent.verbose = self.sb3_output_verbose_level + self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs" + + super()._setup() def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes @@ -145,11 +182,6 @@ class SB3Agent(AgentSessionABC): self._env.close() super().evaluate() - @classmethod - def load(cls, path: Union[str, Path]) -> SB3Agent: - """Load an agent from file.""" - raise NotImplementedError - def save(self): """Save the agent.""" self._agent.save(self._saved_agent_path) diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index b429a2f5..f81163ea 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,4 +1,4 @@ -from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 40e8cf0d..adc9cb32 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True): @app.command() -def session(tc: Optional[str] = None, ldc: Optional[str] = None): +def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None): """ Run a PrimAITE session. @@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): ldc: The lay down config file path. Optional. 
If no value is passed then example default lay down config is used from: ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. + + load: The directory of a previous session. Optional. If no value is passed, then the session + will use the default training config and laydown config. Inversely, if a training config and laydown config + is passed while a session directory is passed, PrimAITE will load the session and ignore the training config + and laydown config. """ from primaite.config.lay_down_config import dos_very_basic_config_path from primaite.config.training_config import main_training_config_path from primaite.main import run + if load is not None: + run(session_path=load) + if not tc: tc = main_training_config_path() diff --git a/src/primaite/main.py b/src/primaite/main.py index f2d1b9c2..9fcc4df6 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -2,7 +2,7 @@ """The main PrimAITE session runner module.""" import argparse from pathlib import Path -from typing import Union +from typing import Optional, Union from primaite import getLogger from primaite.primaite_session import PrimaiteSession @@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__) def run( - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, ): """ Run the PrimAITE Session. - :param training_config_path: The training config filepath. - :param lay_down_config_path: The lay down config filepath. + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
+ :type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ - session = PrimaiteSession(training_config_path, lay_down_config_path) + session = PrimaiteSession(training_config_path, lay_down_config_path, session_path) session.setup() session.learn() @@ -31,9 +36,14 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tc") parser.add_argument("--ldc") + parser.add_argument("--load") + args = parser.parse_args() - if not args.tc: - _LOGGER.error("Please provide a training config file using the --tc " "argument") - if not args.ldc: - _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") - run(training_config_path=args.tc, lay_down_config_path=args.ldc) + if args.load: + run(session_path=args.load) + else: + if not args.tc: + _LOGGER.error("Please provide a training config file using the --tc " "argument") + if not args.ldc: + _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") + run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index caa85e9e..76134238 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -2,10 +2,10 @@ from __future__ import annotations from pathlib import Path -from typing import Dict, Final, Union +from typing import Dict, Final, Optional, Union from primaite import getLogger -from primaite.agents.agent import AgentSessionABC +from primaite.agents.agent_abc import AgentSessionABC from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent @@ -14,6 +14,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType from primaite.config import lay_down_config, 
training_config from primaite.config.training_config import TrainingConfig +from primaite.utils.session_metadata_parser import parse_session_metadata _LOGGER = getLogger(__name__) @@ -27,15 +28,39 @@ class PrimaiteSession: def __init__( self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, ): """ The PrimaiteSession constructor. - :param training_config_path: The training config path. - :param lay_down_config_path: The lay down config path. + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ + self._agent_session: AgentSessionABC = None # noqa + self.session_path: Path = session_path # noqa + self.timestamp_str: str = None # noqa + self.learning_path: Path = None # noqa + self.evaluation_path: Path = None # noqa + + # check if session path is provided + if session_path is not None: + # set load_session to true + self.is_load_session = True + if not isinstance(session_path, Path): + session_path = Path(session_path) + + # if a session path is provided, load it + if not session_path.exists(): + raise Exception(f"Session could not be loaded. 
Path does not exist: {session_path}") + + md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path) + if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path, str]] = training_config_path @@ -46,12 +71,6 @@ class PrimaiteSession: self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self._agent_session: AgentSessionABC = None # noqa - self.session_path: Path = None # noqa - self.timestamp_str: str = None # noqa - self.learning_path: Path = None # noqa - self.evaluation_path: Path = None # noqa - def setup(self): """Performs the session setup.""" if self._training_config.agent_framework == AgentFramework.CUSTOM: @@ -60,11 +79,15 @@ class PrimaiteSession: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = HardCodedNodeAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = HardCodedACLAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -77,11 +100,15 @@ class PrimaiteSession: elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}") if 
self._training_config.action_type == ActionType.NODE: - self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DoNothingNodeAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DoNothingACLAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -93,10 +120,14 @@ class PrimaiteSession: elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") - self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = RandomAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") - self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DummyAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) else: # Invalid AgentFramework AgentIdentifier combo @@ -105,12 +136,12 @@ class PrimaiteSession: elif self._training_config.agent_framework == AgentFramework.SB3: _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") # Stable Baselines3 Agent - self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path) + self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path) elif 
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False):
    """
    Load the metadata of a previous session from its session directory.

    Reads ``session_metadata.json`` from ``session_path``. Unless ``dict_only`` is set, the training config
    and lay down config stored under the ``env`` key of the metadata are also written out as YAML files
    (``tc.yaml`` and ``ldc.yaml``) in a ``.temp`` sub-directory of the session directory.

    :param session_path: Directory that contains the ``session_metadata.json`` file.
    :param dict_only: If True, only the session metadata dict is returned and nothing is written to disk.
    :raises FileNotFoundError: If the session directory or its ``session_metadata.json`` file does not exist.
    :return: The session metadata dict when ``dict_only`` is True; otherwise a list of
        ``[metadata dict, training config YAML path, lay down config YAML path]`` in which both paths
        are ``pathlib.Path`` objects.
    """
    session_path = Path(session_path)

    if not session_path.exists():
        # Session path does not exist
        msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
        _LOGGER.error(msg)
        raise FileNotFoundError(msg)

    # Report a missing metadata file the same way as a missing session
    # directory instead of letting open() fail without a log entry.
    md_file = session_path / "session_metadata.json"
    if not md_file.is_file():
        msg = f"Failed to load PrimAITE Session, metadata file does not exist: {md_file}"
        _LOGGER.error(msg)
        raise FileNotFoundError(msg)

    # Unpack the session_metadata.json file. JSON is UTF-8 by specification,
    # so do not rely on the platform's default encoding.
    with open(md_file, "r", encoding="utf-8") as file:
        md_dict = json.load(file)

    # if dict only, return dict without doing anything else
    if dict_only:
        return md_dict

    # Create a temp directory and dump the training and lay down
    # configs into it
    temp_dir = session_path / ".temp"
    temp_dir.mkdir(exist_ok=True)

    temp_tc = temp_dir / "tc.yaml"
    with open(temp_tc, "w", encoding="utf-8") as file:
        yaml.dump(md_dict["env"]["training_config"], file)

    temp_ldc = temp_dir / "ldc.yaml"
    with open(temp_ldc, "w", encoding="utf-8") as file:
        yaml.dump(md_dict["env"]["lay_down_config"], file)

    # Kept as a list (not a tuple) so existing callers see the same type.
    return [md_dict, temp_tc, temp_ldc]
+ """ + df_dict = pl.read_csv(all_transactions_csv_file).to_dict() + new_dict = {} + + episodes = df_dict["Episode"] + steps = df_dict["Step"] + keys = list(df_dict.keys()) + + for i in range(len(episodes)): + key = (episodes[i], steps[i]) + value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]} + new_dict[key] = value_dict + + return new_dict diff --git a/tests/__init__.py b/tests/__init__.py index 4a0bdce1..31744e29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,3 +4,6 @@ from typing import Final TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config" "The tests config root directory." + +TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets" +"The tests assets root directory." diff --git a/tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip b/tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip new file mode 100644 index 00000000..666151e7 Binary files /dev/null and b/tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip differ diff --git a/tests/assets/example_sb3_agent_session/session_metadata.json b/tests/assets/example_sb3_agent_session/session_metadata.json new file mode 100644 index 00000000..20f6a77c --- /dev/null +++ b/tests/assets/example_sb3_agent_session/session_metadata.json @@ -0,0 +1 @@ +{"uuid": "301874d3-2e14-43c2-ba7f-e2b03ad05dde", "start_datetime": "2023-07-14T09:48:22.973005", "end_datetime": "2023-07-14T09:48:34.182715", "learning": {"total_episodes": 10, "total_time_steps": 2560}, "evaluation": {"total_episodes": 5, "total_time_steps": 1280}, "env": {"training_config": {"agent_framework": "SB3", "deep_learning_framework": "TF2", "agent_identifier": "PPO", "hard_coded_agent_view": "FULL", "random_red_agent": false, "action_type": "NODE", "num_train_episodes": 10, "num_train_steps": 256, "num_eval_episodes": 5, "num_eval_steps": 256, "checkpoint_every_n_episodes": 10, "observation_space": {"components": [{"name": "NODE_LINK_TABLE"}]}, "time_delay": 5, 
"session_type": "TRAIN_EVAL", "load_agent": false, "agent_load_file": null, "observation_space_high_value": 1000000000, "sb3_output_verbose_level": "NONE", "all_ok": 0, "off_should_be_on": -0.001, "off_should_be_resetting": -0.0005, "on_should_be_off": -0.0002, "on_should_be_resetting": -0.0005, "resetting_should_be_on": -0.0005, "resetting_should_be_off": -0.0002, "resetting": -0.0003, "good_should_be_patching": 0.0002, "good_should_be_compromised": 0.0005, "good_should_be_overwhelmed": 0.0005, "patching_should_be_good": -0.0005, "patching_should_be_compromised": 0.0002, "patching_should_be_overwhelmed": 0.0002, "patching": -0.0003, "compromised_should_be_good": -0.002, "compromised_should_be_patching": -0.002, "compromised_should_be_overwhelmed": -0.002, "compromised": -0.002, "overwhelmed_should_be_good": -0.002, "overwhelmed_should_be_patching": -0.002, "overwhelmed_should_be_compromised": -0.002, "overwhelmed": -0.002, "good_should_be_repairing": 0.0002, "good_should_be_restoring": 0.0002, "good_should_be_corrupt": 0.0005, "good_should_be_destroyed": 0.001, "repairing_should_be_good": -0.0005, "repairing_should_be_restoring": 0.0002, "repairing_should_be_corrupt": 0.0002, "repairing_should_be_destroyed": 0.0, "repairing": -0.0003, "restoring_should_be_good": -0.001, "restoring_should_be_repairing": -0.0002, "restoring_should_be_corrupt": 0.0001, "restoring_should_be_destroyed": 0.0002, "restoring": -0.0006, "corrupt_should_be_good": -0.001, "corrupt_should_be_repairing": -0.001, "corrupt_should_be_restoring": -0.001, "corrupt_should_be_destroyed": 0.0002, "corrupt": -0.001, "destroyed_should_be_good": -0.002, "destroyed_should_be_repairing": -0.002, "destroyed_should_be_restoring": -0.002, "destroyed_should_be_corrupt": -0.002, "destroyed": -0.002, "scanning": -0.0002, "red_ier_running": -0.0005, "green_ier_blocked": -0.001, "os_patching_duration": 5, "node_reset_duration": 5, "node_booting_duration": 3, "node_shutdown_duration": 2, 
"service_patching_duration": 5, "file_system_repairing_limit": 5, "file_system_restoring_limit": 5, "file_system_scanning_limit": 5, "deterministic": true, "seed": 12345}, "lay_down_config": [{"item_type": "PORTS", "ports_list": [{"port": "80"}]}, {"item_type": "SERVICES", "service_list": [{"name": "TCP"}]}, {"item_type": "NODE", "node_id": "1", "name": "PC1", "node_class": "SERVICE", "node_type": "COMPUTER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.2", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "NODE", "node_id": "2", "name": "PC2", "node_class": "SERVICE", "node_type": "COMPUTER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.3", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "NODE", "node_id": "3", "name": "SWITCH1", "node_class": "ACTIVE", "node_type": "SWITCH", "priority": "P2", "hardware_state": "ON", "ip_address": "192.168.1.1", "software_state": "GOOD", "file_system_state": "GOOD"}, {"item_type": "NODE", "node_id": "4", "name": "SERVER1", "node_class": "SERVICE", "node_type": "SERVER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.4", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "LINK", "id": "5", "name": "link1", "bandwidth": 1000000000, "source": "1", "destination": "3"}, {"item_type": "LINK", "id": "6", "name": "link2", "bandwidth": 1000000000, "source": "2", "destination": "3"}, {"item_type": "LINK", "id": "7", "name": "link3", "bandwidth": 1000000000, "source": "3", "destination": "4"}, {"item_type": "GREEN_IER", "id": "8", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "1", "destination": "4", "mission_criticality": 1}, {"item_type": "GREEN_IER", "id": "9", "start_step": 1, 
"end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "2", "destination": "4", "mission_criticality": 1}, {"item_type": "GREEN_IER", "id": "10", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "4", "destination": "2", "mission_criticality": 5}, {"item_type": "ACL_RULE", "id": "11", "permission": "ALLOW", "source": "192.168.1.2", "destination": "192.168.1.4", "protocol": "TCP", "port": 80}, {"item_type": "ACL_RULE", "id": "12", "permission": "ALLOW", "source": "192.168.1.3", "destination": "192.168.1.4", "protocol": "TCP", "port": 80}, {"item_type": "ACL_RULE", "id": "13", "permission": "ALLOW", "source": "192.168.1.4", "destination": "192.168.1.3", "protocol": "TCP", "port": 80}, {"item_type": "RED_POL", "id": "14", "start_step": 20, "end_step": 20, "targetNodeId": "1", "initiator": "DIRECT", "type": "SERVICE", "protocol": "TCP", "state": "COMPROMISED", "sourceNodeId": "NA", "sourceNodeService": "NA", "sourceNodeServiceState": "NA"}, {"item_type": "RED_IER", "id": "15", "start_step": 30, "end_step": 256, "load": 10000000, "protocol": "TCP", "port": "80", "source": "1", "destination": "4", "mission_criticality": 0}, {"item_type": "RED_POL", "id": "16", "start_step": 40, "end_step": 40, "targetNodeId": "4", "initiator": "IER", "type": "SERVICE", "protocol": "TCP", "state": "OVERWHELMED", "sourceNodeId": "NA", "sourceNodeService": "NA", "sourceNodeServiceState": "NA"}]}} diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index b02cfe30..f57cac05 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -7,6 +7,14 @@ # "CUSTOM" (Custom Agent) agent_framework: CUSTOM +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). 
+# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + # Sets which Agent class will be used. # Options are: # "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) @@ -17,27 +25,66 @@ agent_framework: CUSTOM # "DUMMY" (primaite.agents.simple.DummyAgent) agent_identifier: DUMMY +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + # Number of episodes for evaluation to run per session num_eval_episodes: 1 # Number of time_steps for evaluation per episode num_eval_steps: 15 -# Time delay between steps (for generic agents) -time_delay: 1 -# Type of session to be run (TRAINING or EVALUATION) +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. 
Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) session_type: EVAL -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] # Environment config values # The high value for the observation space @@ -45,6 +92,13 @@ observation_space_high_value: 1000000000 implicit_acl_rule: DENY max_number_acl_rules: 10 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml index e2b24b41..9e034355 100644 --- a/tests/config/test_random_red_main_config.yaml +++ b/tests/config/test_random_red_main_config.yaml @@ -5,7 +5,15 @@ # "SB3" (Stable Baselines3) # "RLLIB" (Ray RLlib) # "CUSTOM" (Custom Agent) -agent_framework: CUSTOM +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 # Sets which Agent class will be used. # Options are: @@ -15,7 +23,7 @@ agent_framework: CUSTOM # "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) # "RANDOM" (primaite.agents.simple.RandomAgent) # "DUMMY" (primaite.agents.simple.DummyAgent) -agent_identifier: DUMMY +agent_identifier: PPO # Sets whether Red Agent POL and IER is randomised. 
# Options are: @@ -23,92 +31,128 @@ agent_identifier: DUMMY # False random_red_agent: True +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + # Number of episodes for training to run per session -num_train_episodes: 2 +num_train_episodes: 10 # Number of time_steps for training per episode -num_train_steps: 15 +num_train_steps: 256 # Number of episodes for evaluation to run per session -num_eval_episodes: 2 +num_eval_episodes: 1 # Number of time_steps for evaluation per episode -num_eval_steps: 15 -# Time delay between steps (for generic agents) -time_delay: 1 +num_eval_steps: 256 -# Type of session to be run (TRAINING or EVALUATION) -session_type: EVAL -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. 
Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 
-good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/tests/config/training_config_main_rllib.yaml b/tests/config/training_config_main_rllib.yaml new file mode 100644 index 00000000..88f82890 --- /dev/null +++ b/tests/config/training_config_main_rllib.yaml @@ -0,0 +1,163 @@ +# Training Config File + +# Sets which agent algorithm 
framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: RLLIB + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. 
+# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 1 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. 
Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: 
-0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/conftest.py b/tests/conftest.py index e089f2d8..3f022b6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import shutil import tempfile from datetime import datetime from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Tuple, Union from unittest.mock import patch import pytest @@ -13,7 +13,7 @@ import pytest from primaite import getLogger from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession -from primaite.utils.session_output_reader import av_rewards_dict +from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession): super().__init__(training_config_path, lay_down_config_path) self.setup() - def learn_av_reward_per_episode(self) -> Dict[int, float]: + def learn_av_reward_per_episode_dict(self) -> Dict[int, float]: """Get the learn av reward per episode from file.""" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" return av_rewards_dict(self.learning_path / csv_file) - def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: + def 
eval_av_reward_per_episode_dict(self) -> Dict[int, float]: """Get the eval av reward per episode from file.""" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" return av_rewards_dict(self.evaluation_path / csv_file) + def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the learn all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.learning_path / csv_file) + + def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the eval all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.evaluation_path / csv_file) + def metadata_file_as_dict(self) -> Dict[str, Any]: """Read the session_metadata.json file and return as a dict.""" with open(self.session_path / "session_metadata.json", "r") as file: @@ -113,7 +123,7 @@ def temp_primaite_session(request): """ training_config_path = request.param[0] lay_down_config_path = request.param[1] - with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck: + with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck: mck.session_timestamp = datetime.now() return TempPrimaiteSession(training_config_path, lay_down_config_path) diff --git a/tests/test_reward.py b/tests/test_reward.py index bb6eb1b0..2edfd44a 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function( """ with temp_primaite_session as session: session.evaluate() - ev_rewards = session.eval_av_reward_per_episode_csv() + ev_rewards = session.eval_av_reward_per_episode_dict() assert ev_rewards[1] == -8.0 diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py new file mode 100644 index 00000000..645214e3 --- /dev/null +++ b/tests/test_rllib_agent.py @@ -0,0 +1,23 @@ +import pytest + +from primaite import 
@pytest.mark.parametrize(
    "temp_primaite_session",
    [[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
    indirect=True,
)
def test_primaite_session(temp_primaite_session):
    """Run a learning session with the RLlib training config and check the size of its outputs."""
    with temp_primaite_session as session:
        # The session directory must exist before learning starts
        assert session.session_path.exists()
        session.learn()

        # One average-reward entry is expected per training episode
        av_rewards = session.learn_av_reward_per_episode_dict()
        assert len(av_rewards) == 10

        # One transaction entry is expected per (episode, step) pair
        transactions = session.learn_all_transactions_dict()
        assert len(transactions) == 10 * 256
def copy_session_asset(asset_path: Union[str, Path]) -> str:
    """
    Copy a session asset directory into a fresh temporary test folder.

    The copy is placed in ``<system temp dir>/primaite/<random uuid>``.

    :param asset_path: The asset directory to copy.
    :raises ValueError: If no asset path is provided.
    :raises Exception: Any error raised while copying is logged and re-raised.
    :return: The path of the copied asset directory, as a string.
    """
    if asset_path is None:
        raise ValueError("No path provided")

    if isinstance(asset_path, Path):
        asset_path = str(os.path.normpath(asset_path))

    copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))

    # copy the asset into a temp path
    try:
        shutil.copytree(asset_path, copy_path)
    except Exception:
        # Log with the traceback and fail here: returning a path that was never
        # created only defers the failure to a less obvious place in the test.
        _LOGGER.exception(f"Unable to copy directory: {asset_path}")
        raise

    _LOGGER.debug(f"Copied test asset to: {copy_path}")

    # return the copied assets path
    return copy_path
17: -0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + loaded_agent = SB3Agent(session_path=test_path) + + # loaded agent should have the same UUID as the previous agent + assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde" + assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name + assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name + assert loaded_agent._training_config.deterministic + assert loaded_agent._training_config.seed == 12345 + assert str(loaded_agent.session_path) == str(test_path) + + # run another learn session + loaded_agent.learn() + + learn_mean_rewards = av_rewards_dict( + loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv" + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # run an evaluation + loaded_agent.evaluate() + + # load the evaluation average reward csv file + eval_mean_reward = av_rewards_dict( + loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv" + ) + + # the agent config ran the evaluation in deterministic mode, so should have the same reward value + assert len(set(eval_mean_reward.values())) == 1 + + # the evaluation should be the same as a previous run + assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + + # delete the test directory + shutil.rmtree(test_path) + + +def test_load_primaite_session(): + """Test that loading a Primaite session works.""" + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: -0.007798828124999992, + 18: 
-0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + # create loaded session + session = PrimaiteSession(session_path=test_path) + + # run setup on session + session.setup() + + # make sure that the session was loaded correctly + assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde" + assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name + assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name + assert session._agent_session._training_config.deterministic + assert session._agent_session._training_config.seed == 12345 + assert str(session._agent_session.session_path) == str(test_path) + + # run another learn session + session.learn() + + learn_mean_rewards = av_rewards_dict( + session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv" + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # run an evaluation + session.evaluate() + + # load the evaluation average reward csv file + eval_mean_reward = av_rewards_dict( + session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv" + ) + + # the agent config ran the evaluation in deterministic mode, so should have the same reward value + assert len(set(eval_mean_reward.values())) == 1 + + # the evaluation should be the same as a previous run + assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + + # delete the test directory + shutil.rmtree(test_path) + + +def test_run_loading(): + """Test loading session via main.run.""" + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: 
-0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + # create loaded session + run(session_path=test_path) + + learn_mean_rewards = av_rewards_dict( + next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None) + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # delete the test directory + shutil.rmtree(test_path)