diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 066c66b2..0bb03594 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -86,5 +86,5 @@ stages: displayName: 'Perform PrimAITE Setup' - script: | - pytest tests/ + pytest -n 4 displayName: 'Run tests' diff --git a/.gitignore b/.gitignore index 5d6434f1..ef1050e6 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,9 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +tests/assets/**/*.png +tests/assets/**/tensorboard_logs/ +tests/assets/**/checkpoints/ # Translations *.mo diff --git a/docs/_templates/custom-class-template.rst b/docs/_templates/custom-class-template.rst index 8a539bc9..acffdc4c 100644 --- a/docs/_templates/custom-class-template.rst +++ b/docs/_templates/custom-class-template.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + .. Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates. .. diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst index e6ecabd1..8eebad3e 100644 --- a/docs/_templates/custom-module-template.rst +++ b/docs/_templates/custom-module-template.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + .. Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates. .. diff --git a/docs/api.rst b/docs/api.rst index df2bc193..b24dafc3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + .. DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. 
It is hidden diff --git a/docs/conf.py b/docs/conf.py index 51b745cf..8afc1246 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Configuration file for the Sphinx documentation builder. # # For the full list of built-in configuration values, see the documentation: diff --git a/docs/index.rst b/docs/index.rst index fed65919..de5bed46 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + Welcome to PrimAITE's documentation ==================================== @@ -20,17 +24,24 @@ What is PrimAITE built with * `OpenAI's Gym `_ is used as the basis for AI blue agent interaction with the PrimAITE environment * `Networkx `_ is used as the underlying data structure used for the PrimAITE environment * `Stable Baselines 3 `_ is used as a default source of RL algorithms (although PrimAITE is not limited to SB3 agents) +* `Ray RLlib `_ is used as an additional source of RL algorithms +* `Typer `_ is used for building CLIs (Command Line Interface applications) +* `Jupyterlab `_ is used as an extensible environment for interactive and reproducible computing, based on the Jupyter Notebook Architecture +* `Platformdirs `_ is used for finding the right location to store user data and configuration, which varies per platform +* `Plotly `_ is used for building high-level charts + Where next? ------------ -The best place to start is :ref:`about` +Head over to the :ref:`getting-started` page to install and set up PrimAITE! .. toctree:: :maxdepth: 8 :caption: Contents: :hidden: + source/getting_started source/about source/config source/primaite_session @@ -41,12 +52,14 @@ The best place to start is :ref:`about` source/glossary source/migration_1.2_-_2.0 + +.. TODO: Add project links once public repo has been created + .. toctree:: :caption: Project Links: :hidden: -.. 
- #Code <> - #Issues <> - #Pull Requests <> - #Discussions <> + Code + Issues + Pull Requests + Discussions diff --git a/docs/source/about.rst b/docs/source/about.rst index a7135fc0..2068472c 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -1,4 +1,8 @@ -.. _about: +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + +.. _about: About PrimAITE ============== diff --git a/docs/source/config.rst b/docs/source/config.rst index af590a24..058565da 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + .. _config: The Config Files Explained diff --git a/docs/source/custom_agent.rst b/docs/source/custom_agent.rst index b4552d64..ba438305 100644 --- a/docs/source/custom_agent.rst +++ b/docs/source/custom_agent.rst @@ -1,4 +1,8 @@ -Custom Agents +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + +Custom Agents ============= diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst index bbca3fce..0d3f21c3 100644 --- a/docs/source/dependencies.rst +++ b/docs/source/dependencies.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + .. role:: raw-html(raw) :format: html diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst new file mode 100644 index 00000000..13c9d699 --- /dev/null +++ b/docs/source/getting_started.rst @@ -0,0 +1,155 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + +.. _getting-started: + +Getting Started +=============== + +**Getting Started with PrimAITE** + +Pre-Requisites + +In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.10 installed. If you don't already have it, this is how to install it: + + +.. tabs:: lang + + .. 
code-tab:: bash + :caption: Unix + + sudo add-apt-repository ppa:deadsnakes/ppa + sudo apt install python3.10 + sudo apt-get install python3-pip + sudo apt-get install python3-venv + + .. code-tab:: text + :caption: Windows (Powershell) + + - Manual install from: https://www.python.org/downloads/release/python-31011/ + +**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and macOS. + +Install PrimAITE +**************** + +1. Create a primaite directory in your home directory: + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + mkdir ~/primaite + + .. code-tab:: powershell + :caption: Windows (Powershell) + + mkdir ~\primaite + +2. Navigate to the primaite directory and create a new python virtual environment (venv) + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + cd ~/primaite + python3 -m venv .venv + + .. code-tab:: powershell + :caption: Windows (Powershell) + + cd ~\primaite + python3 -m venv .venv + attrib +h .venv /s /d # Hides the .venv directory + +3. Activate the venv + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + source .venv/bin/activate + + .. code-tab:: powershell + :caption: Windows (Powershell) + + .\.venv\Scripts\activate + + +4. Install PrimAITE using pip from PyPI + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + pip install primaite + + .. code-tab:: powershell + :caption: Windows (Powershell) + + pip install primaite + +5. Perform the PrimAITE setup + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + primaite setup + + .. code-tab:: powershell + :caption: Windows (Powershell) + + primaite setup + +Clone & Install PrimAITE for Development +**************************************** + +To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location +of your choice: + +.. TODO:: Add repo path once we know what it is + +.. 
code-block:: bash + + git clone + cd primaite + +Create and activate your Python virtual environment (venv) + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + python3 -m venv venv + source venv/bin/activate + + .. code-tab:: powershell + :caption: Windows (Powershell) + + python3 -m venv venv + .\venv\Scripts\activate + +Install PrimAITE with the dev extra + +.. tabs:: lang + + .. code-tab:: bash + :caption: Unix + + pip install -e .[dev] + + .. code-tab:: powershell + :caption: Windows (Powershell) + + pip install -e .[dev] + + +To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`). diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 58b4cd5e..3422d51e 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + Glossary ============= diff --git a/docs/source/migration_1.2_-_2.0.rst b/docs/source/migration_1.2_-_2.0.rst index 2adf9656..b7c9996d 100644 --- a/docs/source/migration_1.2_-_2.0.rst +++ b/docs/source/migration_1.2_-_2.0.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + v1.2 to v2.0 Migration guide ============================ diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index a393093c..b8895fc7 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -1,3 +1,7 @@ +.. only:: comment + + Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + .. _run a primaite session: Run a PrimAITE Session @@ -10,6 +14,7 @@ A PrimAITE session can be ran either with the ``primaite session`` command from (See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook. 
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters. + .. tabs:: .. code-tab:: bash @@ -19,7 +24,7 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf source ./.venv/bin/activate primaite session ./config/my_training_config.yaml ./config/my_lay_down_config.yaml - .. code-tab:: bash + .. code-tab:: powershell :caption: Powershell CLI cd ~\primaite @@ -42,6 +47,37 @@ The sub-directory is formatted as such: ``~/primaite/sessions//) + +When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory Outputs ------- diff --git a/pyproject.toml b/pyproject.toml index 86418eaa..c5c351a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ "pip-licenses==4.3.0", "pre-commit==2.20.0", "pytest==7.2.0", + "pytest-xdist==3.3.1", "pytest-cov==4.0.0", "pytest-flake8==1.1.1", "setuptools==66", diff --git a/setup.py b/setup.py index 63e905c0..efaf24bf 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from setuptools import setup from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index dacd5c12..c348681d 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import logging import logging.config import sys diff --git a/src/primaite/acl/__init__.py b/src/primaite/acl/__init__.py index 2623efbc..c6fd79f2 100644 --- a/src/primaite/acl/__init__.py +++ b/src/primaite/acl/__init__.py @@ -1,2 +1,2 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
"""Access Control List. Models firewall functionality.""" diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index d4d843e3..007f12a0 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """A class that implements the access control list implementation for the network.""" from typing import Dict diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index 69532376..830cfe35 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """A class that implements an access control list rule.""" diff --git a/src/primaite/agents/__init__.py b/src/primaite/agents/__init__.py index 89580145..d987b43f 100644 --- a/src/primaite/agents/__init__.py +++ b/src/primaite/agents/__init__.py @@ -1 +1,2 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Common interface between RL agents from different libraries and PrimAITE.""" diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent_abc.py similarity index 62% rename from src/primaite/agents/agent.py rename to src/primaite/agents/agent_abc.py index 90860f7d..9b0dd031 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent_abc.py @@ -1,28 +1,24 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
from __future__ import annotations import json -import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Any, Dict, Final, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from uuid import uuid4 -import yaml - import primaite from primaite import getLogger, SESSIONS_DIR from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite +from primaite.utils.session_metadata_parser import parse_session_metadata if TYPE_CHECKING: from logging import Logger - import numpy as np - - _LOGGER: "Logger" = getLogger(__name__) @@ -53,38 +49,63 @@ class AgentSessionABC(ABC): """ @abstractmethod - def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = None, + lay_down_config_path: Optional[Union[str, Path]] = None, + session_path: Optional[Union[str, Path]] = None, + ) -> None: """ - Initialise an agent session from config files. + Initialise an agent session from config files, or load a previous session. + + If training configuration and laydown configuration are provided with a session path, + the session path will be used. :param training_config_path: YAML file containing configurable items defined in `primaite.config.training_config.TrainingConfig` :type training_config_path: Union[path, str] :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
:type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ - if not isinstance(training_config_path, Path): - training_config_path = Path(training_config_path) - self._training_config_path: Final[Union[Path, str]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) - - if not isinstance(lay_down_config_path, Path): - lay_down_config_path = Path(lay_down_config_path) - self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - + # initialise variables self._env: Primaite self._agent = None self._can_learn: bool = False self._can_evaluate: bool = False self.is_eval = False - self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() - "The session timestamp" - self.session_path = get_session_path(self.session_timestamp) - "The Session path" + + # convert session to path + if session_path is not None: + if not isinstance(session_path, Path): + session_path = Path(session_path) + + # if a session path is provided, load it + if not session_path.exists(): + raise Exception(f"Session could not be loaded. 
Path does not exist: {session_path}") + + # load session + self.load(session_path) + else: + # set training config path + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Union[Path, str] = training_config_path + self._training_config: TrainingConfig = training_config.load(self._training_config_path) + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Union[Path, str] = lay_down_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level + + # set random UUID for session + self._uuid = str(uuid4()) + "The session timestamp" + self.session_path = get_session_path(self.session_timestamp) + "The Session path" @property def timestamp_str(self) -> str: @@ -233,51 +254,27 @@ class AgentSessionABC(ABC): def _get_latest_checkpoint(self) -> None: pass - @classmethod - @abstractmethod - def load(cls, path: Union[str, Path]) -> AgentSessionABC: + def load(self, path: Union[str, Path]): """Load an agent from file.""" - if not isinstance(path, Path): - path = Path(path) + md_dict, training_config_path, laydown_config_path = parse_session_metadata(path) - if path.exists(): - # Unpack the session_metadata.json file - md_file = path / "session_metadata.json" - with open(md_file, "r") as file: - md_dict = json.load(file) + # set training config path + self._training_config_path: Union[Path, str] = training_config_path + self._training_config: TrainingConfig = training_config.load(self._training_config_path) + self._lay_down_config_path: Union[Path, str] = laydown_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - # Create a temp directory and dump the training and lay down - # configs into 
it - temp_dir = path / ".temp" - temp_dir.mkdir(exist_ok=True) + # restore the session UUID from the session metadata + self._uuid = md_dict["uuid"] - temp_tc = temp_dir / "tc.yaml" - with open(temp_tc, "w") as file: - yaml.dump(md_dict["env"]["training_config"], file) - - temp_ldc = temp_dir / "ldc.yaml" - with open(temp_ldc, "w") as file: - yaml.dump(md_dict["env"]["lay_down_config"], file) - - agent = cls(temp_tc, temp_ldc) - - agent.session_path = path - - return agent - - else: - # Session path does not exist - msg = f"Failed to load PrimAITE Session, path does not exist: {path}" - _LOGGER.error(msg) - raise FileNotFoundError(msg) + # set the session path + self.session_path = path + "The Session path" @property def _saved_agent_path(self) -> Path: - file_name = ( - f"{self._training_config.agent_framework}_" - f"{self._training_config.agent_identifier}_" - f"{self.timestamp_str}.zip" - ) + file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip" return self.learning_path / file_name @abstractmethod @@ -313,104 +310,3 @@ class AgentSessionABC(ABC): fig = plot_av_reward_per_episode(path, title, subtitle) fig.write_image(image_path) _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") - - -class HardCodedAgentSessionABC(AgentSessionABC): - """ - An Agent Session ABC for evaluation deterministic agents. - - This class cannot be directly instantiated and must be inherited from with all implemented abstract methods - implemented. - """ - - def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: - """ - Initialise a hardcoded agent session. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
- :type lay_down_config_path: Union[path, str] - """ - super().__init__(training_config_path, lay_down_config_path) - self._setup() - - def _setup(self) -> None: - self._env: Primaite = Primaite( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ) - super()._setup() - self._can_learn = False - self._can_evaluate = True - - def _save_checkpoint(self) -> None: - pass - - def _get_latest_checkpoint(self) -> None: - pass - - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Train the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - _LOGGER.warning("Deterministic agents cannot learn") - - @abstractmethod - def _calculate_action(self, obs: np.ndarray) -> None: - pass - - def evaluate( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - self._env.set_as_eval() # noqa - self.is_eval = True - - time_steps = self._training_config.num_eval_steps - episodes = self._training_config.num_eval_episodes - - obs = self._env.reset() - for episode in range(episodes): - # Reset env and collect initial observation - for step in range(time_steps): - # Calculate action - action = self._calculate_action(obs) - - # Perform the step - obs, reward, done, info = self._env.step(action) - - if done: - break - - # Introduce a delay between steps - time.sleep(self._training_config.time_delay / 1000) - obs = self._env.reset() - self._env.close() - super().evaluate() - - @classmethod - def load(cls) -> None: - """Load an agent from file.""" - _LOGGER.warning("Deterministic agents cannot be loaded") - - def save(self) -> None: - """Save the agent.""" - _LOGGER.warning("Deterministic agents cannot be saved") - - def export(self) -> None: - """Export the agent to transportable file format.""" - _LOGGER.warning("Deterministic agents cannot be exported") 
diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py new file mode 100644 index 00000000..ec4b53e7 --- /dev/null +++ b/src/primaite/agents/hardcoded_abc.py @@ -0,0 +1,116 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +import time +from abc import abstractmethod +from pathlib import Path +from typing import Optional, Union + +from primaite import getLogger +from primaite.agents.agent_abc import AgentSessionABC +from primaite.environment.primaite_env import Primaite + +_LOGGER = getLogger(__name__) + + +class HardCodedAgentSessionABC(AgentSessionABC): + """ + An Agent Session ABC for evaluation deterministic agents. + + This class cannot be directly instantiated and must be inherited from with all implemented abstract methods + implemented. + """ + + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, + ): + """ + Initialise a hardcoded agent session. + + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + """ + super().__init__(training_config_path, lay_down_config_path, session_path) + self._setup() + + def _setup(self): + self._env: Primaite = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ) + super()._setup() + self._can_learn = False + self._can_evaluate = True + + def _save_checkpoint(self): + pass + + def _get_latest_checkpoint(self): + pass + + def learn( + self, + **kwargs, + ): + """ + Train the agent. 
+ + :param kwargs: Any agent-specific key-word args to be passed. + """ + _LOGGER.warning("Deterministic agents cannot learn") + + @abstractmethod + def _calculate_action(self, obs): + pass + + def evaluate( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-specific key-word args to be passed. + """ + self._env.set_as_eval() # noqa + self.is_eval = True + + time_steps = self._training_config.num_eval_steps + episodes = self._training_config.num_eval_episodes + + obs = self._env.reset() + for episode in range(episodes): + # Reset env and collect initial observation + for step in range(time_steps): + # Calculate action + action = self._calculate_action(obs) + + # Perform the step + obs, reward, done, info = self._env.step(action) + + if done: + break + + # Introduce a delay between steps + time.sleep(self._training_config.time_delay / 1000) + obs = self._env.reset() + self._env.close() + + @classmethod + def load(cls, path=None): + """Load an agent from file.""" + _LOGGER.warning("Deterministic agents cannot be loaded") + + def save(self): + """Save the agent.""" + _LOGGER.warning("Deterministic agents cannot be saved") + + def export(self): + """Export the agent to transportable file format.""" + _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 0ac5022c..b8c49c14 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,10 +1,11 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
from typing import Dict, List, Union import numpy as np from primaite.acl.access_control_list import AccessControlList from primaite.acl.acl_rule import ACLRule -from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, get_node_of_ip, diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index b74c3a0b..10cc2b72 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,6 +1,7 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import numpy as np -from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 0281de7e..8afc98a1 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,10 +1,11 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
from __future__ import annotations import json import shutil from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union from uuid import uuid4 from ray.rllib.algorithms import Algorithm @@ -14,7 +15,7 @@ from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env from primaite import getLogger -from primaite.agents.agent import AgentSessionABC +from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite @@ -48,7 +49,12 @@ def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: class RLlibAgent(AgentSessionABC): """An AgentSession class that implements a Ray RLlib agent.""" - def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, + ) -> None: """ Initialise the RLLib Agent training session. 
@@ -61,6 +67,13 @@ class RLlibAgent(AgentSessionABC): :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` or `A2C`) """ + # TODO: implement RLlib agent loading + if session_path is not None: + msg = "RLlib agent loading has not been implemented yet" + _LOGGER.error(msg) + print(msg) + raise NotImplementedError + super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 462360a0..881426ab 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,14 +1,16 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from __future__ import annotations +import json from pathlib import Path -from typing import Any, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union import numpy as np from stable_baselines3 import A2C, PPO from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import getLogger -from primaite.agents.agent import AgentSessionABC +from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite @@ -21,7 +23,12 @@ _LOGGER: "Logger" = getLogger(__name__) class SB3Agent(AgentSessionABC): """An AgentSession class that implements a Stable Baselines3 agent.""" - def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = None, + lay_down_config_path: Optional[Union[str, Path]] = None, + session_path: Optional[Union[str, Path]] = None, + ) -> None: """ Initialise the SB3 Agent training session. 
@@ -34,7 +41,7 @@ class SB3Agent(AgentSessionABC): :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` or `A2C`) """ - super().__init__(training_config_path, lay_down_config_path) + super().__init__(training_config_path, lay_down_config_path, session_path) if not self._training_config.agent_framework == AgentFramework.SB3: msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) @@ -51,7 +58,7 @@ class SB3Agent(AgentSessionABC): self._tensorboard_log_path = self.learning_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) - self._setup() + _LOGGER.debug( f"Created {self.__class__.__name__} using: " f"agent_framework={self._training_config.agent_framework}, " @@ -61,8 +68,10 @@ class SB3Agent(AgentSessionABC): self.is_eval = False + self._setup() + def _setup(self) -> None: - super()._setup() + """Set up the SB3 Agent.""" self._env = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, @@ -70,14 +79,43 @@ class SB3Agent(AgentSessionABC): timestamp_str=self.timestamp_str, ) - self._agent = self._agent_class( - PPOMlp, - self._env, - verbose=self.sb3_output_verbose_level, - n_steps=self._training_config.num_train_steps, - tensorboard_log=str(self._tensorboard_log_path), - seed=self._training_config.seed, - ) + # check if there is a zip file that needs to be loaded + load_file = next(self.session_path.rglob("*.zip"), None) + + if not load_file: + # create a new env and agent + + self._agent = self._agent_class( + PPOMlp, + self._env, + verbose=self.sb3_output_verbose_level, + n_steps=self._training_config.num_train_steps, + tensorboard_log=str(self._tensorboard_log_path), + seed=self._training_config.seed, + ) + else: + # set env values from session metadata + with open(self.session_path / "session_metadata.json", "r") as file: + md_dict = json.load(file) + + # load 
environment values + if self.is_eval: + # evaluation always starts at 0 + self._env.episode_count = 0 + self._env.total_step_count = 0 + else: + # carry on from previous learning sessions + self._env.episode_count = md_dict["learning"]["total_episodes"] + self._env.total_step_count = md_dict["learning"]["total_time_steps"] + + # load the file + self._agent = self._agent_class.load(load_file, env=self._env) + + # set agent values + self._agent.verbose = self.sb3_output_verbose_level + self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs" + + super()._setup() def _save_checkpoint(self) -> None: checkpoint_n = self._training_config.checkpoint_every_n_episodes @@ -149,11 +187,6 @@ class SB3Agent(AgentSessionABC): self._env.close() super().evaluate() - @classmethod - def load(cls, path: Union[str, Path]) -> SB3Agent: - """Load an agent from file.""" - raise NotImplementedError - def save(self) -> None: """Save the agent.""" self._agent.save(self._saved_agent_path) diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index 2c130c0c..bfc7bcf2 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,6 +1,7 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from typing import TYPE_CHECKING -from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum if TYPE_CHECKING: diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 353978f1..ff0ca8d2 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
from typing import Dict, List, Union import numpy as np diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 863cbfd2..14db236c 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Provides a CLI using Typer as an entry point.""" import logging import os @@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True) -> None: @app.command() -def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None: +def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None: """ Run a PrimAITE session. @@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None: ldc: The lay down config file path. Optional. If no value is passed then example default lay down config is used from: ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. + + load: The directory of a previous session. Optional. If no value is passed, then the session + will use the default training config and laydown config. Inversely, if a training config and laydown config + is passed while a session directory is passed, PrimAITE will load the session and ignore the training config + and laydown config. """ from primaite.config.lay_down_config import dos_very_basic_config_path from primaite.config.training_config import main_training_config_path from primaite.main import run + if load is not None: + run(session_path=load) + if not tc: tc = main_training_config_path() diff --git a/src/primaite/common/__init__.py b/src/primaite/common/__init__.py index 5f47b0b5..738a30d1 100644 --- a/src/primaite/common/__init__.py +++ b/src/primaite/common/__init__.py @@ -1,2 +1,2 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
"""Objects which are shared between many PrimAITE modules.""" diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index ff090ca9..0209c64d 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Enumerations for APE.""" from enum import Enum, IntEnum diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py index f7a757e8..048ed0ab 100644 --- a/src/primaite/common/protocol.py +++ b/src/primaite/common/protocol.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """The protocol class.""" diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py index 1351a30d..7ee694db 100644 --- a/src/primaite/common/service.py +++ b/src/primaite/common/service.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """The Service class.""" from primaite.common.enums import SoftwareState diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py index 03ed4cf1..9bd899f7 100644 --- a/src/primaite/config/__init__.py +++ b/src/primaite/config/__init__.py @@ -1 +1,2 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Configuration parameters for running experiments.""" diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 2cc5f9c2..80b0f619 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
from pathlib import Path from typing import Any, Dict, Final, TYPE_CHECKING, Union diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 628e2818..f618b37c 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from __future__ import annotations from dataclasses import dataclass, field diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py index db6ce6c8..ad43c141 100644 --- a/src/primaite/data_viz/__init__.py +++ b/src/primaite/data_viz/__init__.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Utility to generate plots of sessions metrics after PrimAITE.""" from enum import Enum diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py index 245b9774..39c2b4cc 100644 --- a/src/primaite/data_viz/session_plots.py +++ b/src/primaite/data_viz/session_plots.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from pathlib import Path from typing import Dict, Optional, Union diff --git a/src/primaite/environment/__init__.py b/src/primaite/environment/__init__.py index 8b0060c0..e837fe1e 100644 --- a/src/primaite/environment/__init__.py +++ b/src/primaite/environment/__init__.py @@ -1,2 +1,2 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network.""" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index cb9872d1..ebc47043 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index f78b5f8d..8f34204b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy import logging diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index a0efac4d..aad15246 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Implements reward function.""" from typing import Dict, TYPE_CHECKING, Union diff --git a/src/primaite/links/__init__.py b/src/primaite/links/__init__.py index 6257f282..21ce44ba 100644 --- a/src/primaite/links/__init__.py +++ b/src/primaite/links/__init__.py @@ -1,2 +1,2 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
"""Network connections between nodes in the simulation.""" diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 145de5f3..aa3fa7fb 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """The link class.""" from typing import List diff --git a/src/primaite/main.py b/src/primaite/main.py index 78420972..aed39d73 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,8 +1,8 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """The main PrimAITE session runner module.""" import argparse from pathlib import Path -from typing import Union +from typing import Optional, Union from primaite import getLogger from primaite.primaite_session import PrimaiteSession @@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__) def run( - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, ) -> None: """ Run the PrimAITE Session. - :param training_config_path: The training config filepath. - :param lay_down_config_path: The lay down config filepath. + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
+ :type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ - session = PrimaiteSession(training_config_path, lay_down_config_path) + session = PrimaiteSession(training_config_path, lay_down_config_path, session_path) session.setup() session.learn() @@ -31,9 +36,14 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tc") parser.add_argument("--ldc") + parser.add_argument("--load") + args = parser.parse_args() - if not args.tc: - _LOGGER.error("Please provide a training config file using the --tc " "argument") - if not args.ldc: - _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") - run(training_config_path=args.tc, lay_down_config_path=args.ldc) + if args.load: + run(session_path=args.load) + else: + if not args.tc: + _LOGGER.error("Please provide a training config file using the --tc " "argument") + if not args.ldc: + _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") + run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/nodes/__init__.py b/src/primaite/nodes/__init__.py index 19347372..43b213d6 100644 --- a/src/primaite/nodes/__init__.py +++ b/src/primaite/nodes/__init__.py @@ -1,2 +1,2 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Nodes represent network hosts in the simulation.""" diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index b73f80f0..b5df70b5 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """An Active Node (i.e. 
not an actuator).""" import logging from typing import Final diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 7dd7d962..9118fa46 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """The base Node class.""" from typing import Final diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 0826efe6..8e03b40f 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" from typing import TYPE_CHECKING, Union diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index abbe07ad..786e93ac 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" from dataclasses import dataclass from typing import TYPE_CHECKING, Union diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index c79636e3..88c8cc85 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """The Passive Node class (i.e. 
an actuator).""" from primaite.common.enums import HardwareState, NodeType, Priority from primaite.config.training_config import TrainingConfig diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index ef0cd92e..ce1ffe92 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """A Service Node (i.e. not an actuator).""" import logging from typing import Dict, Final diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 6bb5abf4..eaf10005 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -1,5 +1,6 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Contains default jupyter notebooks which demonstrate PrimAITE functionality.""" -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. + import importlib.util import os import subprocess diff --git a/src/primaite/pol/__init__.py b/src/primaite/pol/__init__.py index cba4b28b..1adb1491 100644 --- a/src/primaite/pol/__init__.py +++ b/src/primaite/pol/__init__.py @@ -1,2 +1,2 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Pattern of Life- Represents the actions of users on the network.""" -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 7df87590..0425a831 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
"""Implements Pattern of Life on the network (nodes and links).""" from typing import Dict diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py index b46dbf22..7fab340d 100644 --- a/src/primaite/pol/ier.py +++ b/src/primaite/pol/ier.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """ Information Exchange Requirements for APE. diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index 2801e8b0..ad55fa24 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Implements POL on the network (nodes and links) resulting from the red agent attack.""" from typing import Dict diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 5ef856d7..ab3c2150 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -1,11 +1,12 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Main entry point to PrimAITE. 
Configure training/evaluation experiments and input/output.""" from __future__ import annotations from pathlib import Path -from typing import Any, Dict, Final, Union +from typing import Any, Dict, Final, Optional, Union from primaite import getLogger -from primaite.agents.agent import AgentSessionABC +from primaite.agents.agent_abc import AgentSessionABC from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent @@ -14,6 +15,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig +from primaite.utils.session_metadata_parser import parse_session_metadata _LOGGER = getLogger(__name__) @@ -27,15 +29,39 @@ class PrimaiteSession: def __init__( self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, ) -> None: """ The PrimaiteSession constructor. - :param training_config_path: The training config path. - :param lay_down_config_path: The lay down config path. + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
+ :type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ + self._agent_session: AgentSessionABC = None # noqa + self.session_path: Path = session_path # noqa + self.timestamp_str: str = None # noqa + self.learning_path: Path = None # noqa + self.evaluation_path: Path = None # noqa + + # check if session path is provided + if session_path is not None: + # set load_session to true + self.is_load_session = True + if not isinstance(session_path, Path): + session_path = Path(session_path) + + # if a session path is provided, load it + if not session_path.exists(): + raise Exception(f"Session could not be loaded. Path does not exist: {session_path}") + + md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path) + if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path, str]] = training_config_path @@ -60,11 +86,15 @@ class PrimaiteSession: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = HardCodedNodeAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = HardCodedACLAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -77,11 +107,15 @@ class PrimaiteSession: elif self._training_config.agent_identifier == 
AgentIdentifier.DO_NOTHING: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}") if self._training_config.action_type == ActionType.NODE: - self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DoNothingNodeAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DoNothingACLAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -93,10 +127,14 @@ class PrimaiteSession: elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") - self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = RandomAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") - self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DummyAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) else: # Invalid AgentFramework AgentIdentifier combo @@ -105,12 +143,12 @@ class PrimaiteSession: elif self._training_config.agent_framework == AgentFramework.SB3: _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") # Stable Baselines3 Agent - self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path) + 
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path) elif self._training_config.agent_framework == AgentFramework.RLLIB: _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") # Ray RLlib Agent - self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path) else: # Invalid AgentFramework diff --git a/src/primaite/setup/__init__.py b/src/primaite/setup/__init__.py index 3c0bfe14..acfa48c4 100644 --- a/src/primaite/setup/__init__.py +++ b/src/primaite/setup/__init__.py @@ -1,2 +1,2 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Utilities to prepare the user's data folders.""" -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. diff --git a/src/primaite/setup/old_installation_clean_up.py b/src/primaite/setup/old_installation_clean_up.py index 1603f06e..43950e4f 100644 --- a/src/primaite/setup/old_installation_clean_up.py +++ b/src/primaite/setup/old_installation_clean_up.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from typing import TYPE_CHECKING from primaite import getLogger diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 530a2c30..775f43b5 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
import filecmp import os import shutil diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index 99d04149..df3b36a1 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import filecmp import os import shutil diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index 1288e63c..56f16a08 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from typing import TYPE_CHECKING from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR diff --git a/src/primaite/transactions/__init__.py b/src/primaite/transactions/__init__.py index 45315b22..9a881fd5 100644 --- a/src/primaite/transactions/__init__.py +++ b/src/primaite/transactions/__init__.py @@ -1,2 +1,2 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Record data of the system's state and agent's observations and actions.""" -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 09ec2cec..1a702748 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
"""The Transaction class.""" from datetime import datetime from typing import List, Optional, Tuple, TYPE_CHECKING, Union diff --git a/src/primaite/utils/__init__.py b/src/primaite/utils/__init__.py index 55e8a6ba..5dbd1e5f 100644 --- a/src/primaite/utils/__init__.py +++ b/src/primaite/utils/__init__.py @@ -1 +1,2 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Utilities for PrimAITE.""" diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py index a994f880..b9abca8f 100644 --- a/src/primaite/utils/package_data.py +++ b/src/primaite/utils/package_data.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import os from pathlib import Path from typing import TYPE_CHECKING diff --git a/src/primaite/utils/session_metadata_parser.py b/src/primaite/utils/session_metadata_parser.py new file mode 100644 index 00000000..eb3c3339 --- /dev/null +++ b/src/primaite/utils/session_metadata_parser.py @@ -0,0 +1,59 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +import json +from pathlib import Path +from typing import Union + +import yaml + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def parse_session_metadata(session_path: Union[Path, str], dict_only=False): + """ + Loads a session metadata from the given directory path. 
+ + :param session_path: Directory where the session metadata file is in + :param dict_only: If dict_only is true, the function will only return the dict contents of session metadata + + :return: Dictionary which has all the session metadata contents + :rtype: Dict + + :return: Path where the YAML copy of the training config is dumped into + :rtype: str + :return: Path where the YAML copy of the laydown config is dumped into + :rtype: str + """ + if not isinstance(session_path, Path): + session_path = Path(session_path) + + if not session_path.exists(): + # Session path does not exist + msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + + # Unpack the session_metadata.json file + md_file = session_path / "session_metadata.json" + with open(md_file, "r") as file: + md_dict = json.load(file) + + # if dict only, return dict without doing anything else + if dict_only: + return md_dict + + # Create a temp directory and dump the training and lay down + # configs into it + temp_dir = session_path / ".temp" + temp_dir.mkdir(exist_ok=True) + + temp_tc = temp_dir / "tc.yaml" + with open(temp_tc, "w") as file: + yaml.dump(md_dict["env"]["training_config"], file) + + temp_ldc = temp_dir / "ldc.yaml" + with open(temp_ldc, "w") as file: + yaml.dump(md_dict["env"]["lay_down_config"], file) + + return [md_dict, temp_tc, temp_ldc] diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py index ad3dd4f4..7089c69a 100644 --- a/src/primaite/utils/session_output_reader.py +++ b/src/primaite/utils/session_output_reader.py @@ -1,5 +1,6 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from pathlib import Path -from typing import Dict, Union +from typing import Any, Dict, Tuple, Union # Using polars as it's faster than Pandas; it will speed things up when # files get big! 
@@ -13,8 +14,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: The dictionary keys are the episode number, and the values are the mean reward that episode. :param av_rewards_csv_file: The average rewards per episode csv file path. - :return: The average rewards per episode cdv as a dict. + :return: The average rewards per episode csv as a dict. """ - df = pl.read_csv(av_rewards_csv_file).to_dict() + df_dict = pl.read_csv(av_rewards_csv_file).to_dict() - return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])} + return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])} + + +def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]: + """ + Read an all transactions csv file and return as a dict. + + The dict keys are a tuple with the structure (episode, step). The dict + values are the remaining columns as a dict. + + :param all_transactions_csv_file: The all transactions csv file path. + :return: The all transactions csv file as a dict. + """ + df_dict = pl.read_csv(all_transactions_csv_file).to_dict() + new_dict = {} + + episodes = df_dict["Episode"] + steps = df_dict["Step"] + keys = list(df_dict.keys()) + + for i in range(len(episodes)): + key = (episodes[i], steps[i]) + value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]} + new_dict[key] = value_dict + + return new_dict diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index d05f69b1..e7f1b248 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
import csv from logging import Logger from typing import Final, List, Tuple, TYPE_CHECKING, Union diff --git a/tests/__init__.py b/tests/__init__.py index 4a0bdce1..f8e6fc55 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,9 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. from pathlib import Path from typing import Final TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config" "The tests config root directory." + +TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets" +"The tests assets root directory." diff --git a/tests/assets/example_sb3_agent_session/session_metadata.json b/tests/assets/example_sb3_agent_session/session_metadata.json new file mode 100644 index 00000000..20f6a77c --- /dev/null +++ b/tests/assets/example_sb3_agent_session/session_metadata.json @@ -0,0 +1 @@ +{"uuid": "301874d3-2e14-43c2-ba7f-e2b03ad05dde", "start_datetime": "2023-07-14T09:48:22.973005", "end_datetime": "2023-07-14T09:48:34.182715", "learning": {"total_episodes": 10, "total_time_steps": 2560}, "evaluation": {"total_episodes": 5, "total_time_steps": 1280}, "env": {"training_config": {"agent_framework": "SB3", "deep_learning_framework": "TF2", "agent_identifier": "PPO", "hard_coded_agent_view": "FULL", "random_red_agent": false, "action_type": "NODE", "num_train_episodes": 10, "num_train_steps": 256, "num_eval_episodes": 5, "num_eval_steps": 256, "checkpoint_every_n_episodes": 10, "observation_space": {"components": [{"name": "NODE_LINK_TABLE"}]}, "time_delay": 5, "session_type": "TRAIN_EVAL", "load_agent": false, "agent_load_file": null, "observation_space_high_value": 1000000000, "sb3_output_verbose_level": "NONE", "all_ok": 0, "off_should_be_on": -0.001, "off_should_be_resetting": -0.0005, "on_should_be_off": -0.0002, "on_should_be_resetting": -0.0005, "resetting_should_be_on": -0.0005, "resetting_should_be_off": -0.0002, "resetting": -0.0003, 
"good_should_be_patching": 0.0002, "good_should_be_compromised": 0.0005, "good_should_be_overwhelmed": 0.0005, "patching_should_be_good": -0.0005, "patching_should_be_compromised": 0.0002, "patching_should_be_overwhelmed": 0.0002, "patching": -0.0003, "compromised_should_be_good": -0.002, "compromised_should_be_patching": -0.002, "compromised_should_be_overwhelmed": -0.002, "compromised": -0.002, "overwhelmed_should_be_good": -0.002, "overwhelmed_should_be_patching": -0.002, "overwhelmed_should_be_compromised": -0.002, "overwhelmed": -0.002, "good_should_be_repairing": 0.0002, "good_should_be_restoring": 0.0002, "good_should_be_corrupt": 0.0005, "good_should_be_destroyed": 0.001, "repairing_should_be_good": -0.0005, "repairing_should_be_restoring": 0.0002, "repairing_should_be_corrupt": 0.0002, "repairing_should_be_destroyed": 0.0, "repairing": -0.0003, "restoring_should_be_good": -0.001, "restoring_should_be_repairing": -0.0002, "restoring_should_be_corrupt": 0.0001, "restoring_should_be_destroyed": 0.0002, "restoring": -0.0006, "corrupt_should_be_good": -0.001, "corrupt_should_be_repairing": -0.001, "corrupt_should_be_restoring": -0.001, "corrupt_should_be_destroyed": 0.0002, "corrupt": -0.001, "destroyed_should_be_good": -0.002, "destroyed_should_be_repairing": -0.002, "destroyed_should_be_restoring": -0.002, "destroyed_should_be_corrupt": -0.002, "destroyed": -0.002, "scanning": -0.0002, "red_ier_running": -0.0005, "green_ier_blocked": -0.001, "os_patching_duration": 5, "node_reset_duration": 5, "node_booting_duration": 3, "node_shutdown_duration": 2, "service_patching_duration": 5, "file_system_repairing_limit": 5, "file_system_restoring_limit": 5, "file_system_scanning_limit": 5, "deterministic": true, "seed": 12345}, "lay_down_config": [{"item_type": "PORTS", "ports_list": [{"port": "80"}]}, {"item_type": "SERVICES", "service_list": [{"name": "TCP"}]}, {"item_type": "NODE", "node_id": "1", "name": "PC1", "node_class": "SERVICE", "node_type": "COMPUTER", 
"priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.2", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "NODE", "node_id": "2", "name": "PC2", "node_class": "SERVICE", "node_type": "COMPUTER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.3", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "NODE", "node_id": "3", "name": "SWITCH1", "node_class": "ACTIVE", "node_type": "SWITCH", "priority": "P2", "hardware_state": "ON", "ip_address": "192.168.1.1", "software_state": "GOOD", "file_system_state": "GOOD"}, {"item_type": "NODE", "node_id": "4", "name": "SERVER1", "node_class": "SERVICE", "node_type": "SERVER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.4", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "LINK", "id": "5", "name": "link1", "bandwidth": 1000000000, "source": "1", "destination": "3"}, {"item_type": "LINK", "id": "6", "name": "link2", "bandwidth": 1000000000, "source": "2", "destination": "3"}, {"item_type": "LINK", "id": "7", "name": "link3", "bandwidth": 1000000000, "source": "3", "destination": "4"}, {"item_type": "GREEN_IER", "id": "8", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "1", "destination": "4", "mission_criticality": 1}, {"item_type": "GREEN_IER", "id": "9", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "2", "destination": "4", "mission_criticality": 1}, {"item_type": "GREEN_IER", "id": "10", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "4", "destination": "2", "mission_criticality": 5}, {"item_type": "ACL_RULE", "id": "11", "permission": "ALLOW", "source": "192.168.1.2", "destination": 
"192.168.1.4", "protocol": "TCP", "port": 80}, {"item_type": "ACL_RULE", "id": "12", "permission": "ALLOW", "source": "192.168.1.3", "destination": "192.168.1.4", "protocol": "TCP", "port": 80}, {"item_type": "ACL_RULE", "id": "13", "permission": "ALLOW", "source": "192.168.1.4", "destination": "192.168.1.3", "protocol": "TCP", "port": 80}, {"item_type": "RED_POL", "id": "14", "start_step": 20, "end_step": 20, "targetNodeId": "1", "initiator": "DIRECT", "type": "SERVICE", "protocol": "TCP", "state": "COMPROMISED", "sourceNodeId": "NA", "sourceNodeService": "NA", "sourceNodeServiceState": "NA"}, {"item_type": "RED_IER", "id": "15", "start_step": 30, "end_step": 256, "load": 10000000, "protocol": "TCP", "port": "80", "source": "1", "destination": "4", "mission_criticality": 0}, {"item_type": "RED_POL", "id": "16", "start_step": 40, "end_step": 40, "targetNodeId": "4", "initiator": "IER", "type": "SERVICE", "protocol": "TCP", "state": "OVERWHELMED", "sourceNodeId": "NA", "sourceNodeService": "NA", "sourceNodeServiceState": "NA"}]}} diff --git a/tests/config/legacy_conversion/legacy_training_config.yaml b/tests/config/legacy_conversion/legacy_training_config.yaml index 5c2025a2..fb24e3d7 100644 --- a/tests/config/legacy_conversion/legacy_training_config.yaml +++ b/tests/config/legacy_conversion/legacy_training_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Main Config File # Generic config values diff --git a/tests/config/legacy_conversion/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml index c57741f7..3df29d04 100644 --- a/tests/config/legacy_conversion/new_training_config.yaml +++ b/tests/config/legacy_conversion/new_training_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
# Main Config File # Generic config values diff --git a/tests/config/obs_tests/laydown.yaml b/tests/config/obs_tests/laydown.yaml index ef77ce83..3590492b 100644 --- a/tests/config/obs_tests/laydown.yaml +++ b/tests/config/obs_tests/laydown.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. - item_type: PORTS ports_list: - port: '80' diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml index 2ac8f59a..8374115d 100644 --- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml index a9986d5b..c68199a0 100644 --- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml index a129712c..c662e715 100644 --- a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. 
diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml index 03d11b82..bd23bded 100644 --- a/tests/config/obs_tests/main_config_without_obs.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index aadbd449..65257d62 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. - item_type: PORTS ports_list: - port: '21' diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index dd425a8c..133b2af8 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. @@ -7,6 +8,14 @@ # "CUSTOM" (Custom Agent) agent_framework: CUSTOM +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + # Sets which Agent class will be used. # Options are: # "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) @@ -17,32 +26,78 @@ agent_framework: CUSTOM # "DUMMY" (primaite.agents.simple.DummyAgent) agent_identifier: DUMMY +# Sets whether Red Agent POL and IER is randomised. 
+# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + # Number of episodes for evaluation to run per session num_eval_episodes: 1 # Number of time_steps for evaluation per episode num_eval_steps: 15 -# Time delay between steps (for generic agents) -time_delay: 1 -# Type of session to be run (TRAINING or EVALUATION) +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. 
Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) session_type: EVAL -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml index 14b3f087..1b1d5deb 100644 --- a/tests/config/ppo_not_seeded_training_config.yaml +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/ppo_seeded_training_config.yaml b/tests/config/ppo_seeded_training_config.yaml index a176c793..14a4face 100644 --- a/tests/config/ppo_seeded_training_config.yaml +++ b/tests/config/ppo_seeded_training_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 0f378634..2fcca1f2 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
# Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index 9d05b84a..9fb82ac2 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. - item_type: PORTS ports_list: - port: '21' diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index c875757f..625491fe 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml index e2b24b41..3416029c 100644 --- a/tests/config/test_random_red_main_config.yaml +++ b/tests/config/test_random_red_main_config.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. @@ -5,7 +6,15 @@ # "SB3" (Stable Baselines3) # "RLLIB" (Ray RLlib) # "CUSTOM" (Custom Agent) -agent_framework: CUSTOM +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 # Sets which Agent class will be used. 
# Options are: @@ -15,7 +24,7 @@ agent_framework: CUSTOM # "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) # "RANDOM" (primaite.agents.simple.RandomAgent) # "DUMMY" (primaite.agents.simple.DummyAgent) -agent_identifier: DUMMY +agent_identifier: PPO # Sets whether Red Agent POL and IER is randomised. # Options are: @@ -23,92 +32,128 @@ agent_identifier: DUMMY # False random_red_agent: True +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + # Number of episodes for training to run per session -num_train_episodes: 2 +num_train_episodes: 10 # Number of time_steps for training per episode -num_train_steps: 15 +num_train_steps: 256 # Number of episodes for evaluation to run per session -num_eval_episodes: 2 +num_eval_episodes: 1 # Number of time_steps for evaluation per episode -num_eval_steps: 15 -# Time delay between steps (for generic agents) -time_delay: 1 +num_eval_steps: 256 -# Type of session to be run (TRAINING or EVALUATION) -session_type: EVAL -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. 
Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 
+overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/tests/config/train_episode_step.yaml b/tests/config/train_episode_step.yaml index f112b741..31337b0c 100644 --- 
a/tests/config/train_episode_step.yaml +++ b/tests/config/train_episode_step.yaml @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/training_config_main_rllib.yaml b/tests/config/training_config_main_rllib.yaml new file mode 100644 index 00000000..40cbc0fc --- /dev/null +++ b/tests/config/training_config_main_rllib.yaml @@ -0,0 +1,164 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: RLLIB + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. 
+# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 1 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 + +# Sets how often the agent will save a checkpoint (every n episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run.
Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: 
-0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time taken to repair the file system +file_system_restoring_limit: 5 # The time taken to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/conftest.py b/tests/conftest.py index aaf4dbce..9b0db139 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import datetime import json import shutil import tempfile from datetime import datetime from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Tuple, Union from unittest.mock import patch import pytest @@ -13,7 +13,7 @@ import pytest from primaite import getLogger from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession -from primaite.utils.session_output_reader import av_rewards_dict +from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession): super().__init__(training_config_path, lay_down_config_path) self.setup() - def learn_av_reward_per_episode(self) -> Dict[int, float]: + def learn_av_reward_per_episode_dict(self) -> Dict[int, float]: """Get the learn av reward per episode from file.""" csv_file =
f"average_reward_per_episode_{self.timestamp_str}.csv" return av_rewards_dict(self.learning_path / csv_file) - def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: + def eval_av_reward_per_episode_dict(self) -> Dict[int, float]: """Get the eval av reward per episode from file.""" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" return av_rewards_dict(self.evaluation_path / csv_file) + def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the learn all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.learning_path / csv_file) + + def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the eval all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.evaluation_path / csv_file) + def metadata_file_as_dict(self) -> Dict[str, Any]: """Read the session_metadata.json file and return as a dict.""" with open(self.session_path / "session_metadata.json", "r") as file: @@ -62,7 +72,6 @@ class TempPrimaiteSession(PrimaiteSession): def __exit__(self, type, value, tb): shutil.rmtree(self.session_path) - shutil.rmtree(self.session_path.parent) _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") @@ -114,7 +123,7 @@ def temp_primaite_session(request): """ training_config_path = request.param[0] lay_down_config_path = request.param[1] - with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck: + with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck: mck.session_timestamp = datetime.now() return TempPrimaiteSession(training_config_path, lay_down_config_path) diff --git a/tests/mock_and_patch/__init__.py b/tests/mock_and_patch/__init__.py index e69de29b..778748f7 100644 --- a/tests/mock_and_patch/__init__.py +++ b/tests/mock_and_patch/__init__.py @@ -0,0 +1 @@ +# Crown Owned 
Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index 90c0cb5d..190e1dba 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import tempfile from datetime import datetime from pathlib import Path diff --git a/tests/test_acl.py b/tests/test_acl.py index 30f12697..4ef9d78c 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Used to tes the ACL functions.""" from primaite.acl.access_control_list import AccessControlList diff --git a/tests/test_active_node.py b/tests/test_active_node.py index addc595c..880c0f02 100644 --- a/tests/test_active_node.py +++ b/tests/test_active_node.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Used to test Active Node functions.""" import pytest diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index d5844fd9..3bcdb66d 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Test env creation and behaviour with different observation spaces.""" import numpy as np import pytest diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index 75ea5882..210d931e 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
import os import pytest diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index f8885f3e..3496ed9d 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import pytest from primaite.config.lay_down_config import data_manipulation_config_path diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index fb7dc83d..80e13c5b 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Used to test Active Node functions.""" import pytest diff --git a/tests/test_reward.py b/tests/test_reward.py index bb6eb1b0..741c6f13 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import pytest from primaite import getLogger @@ -48,5 +49,5 @@ def test_rewards_are_being_penalised_at_each_step_function( """ with temp_primaite_session as session: session.evaluate() - ev_rewards = session.eval_av_reward_per_episode_csv() + ev_rewards = session.eval_av_reward_per_episode_dict() assert ev_rewards[1] == -8.0 diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py new file mode 100644 index 00000000..f494ea81 --- /dev/null +++ b/tests/test_rllib_agent.py @@ -0,0 +1,24 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
+import pytest + +from primaite import getLogger +from primaite.config.lay_down_config import dos_very_basic_config_path +from tests import TEST_CONFIG_ROOT + +_LOGGER = getLogger(__name__) + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]], + indirect=True, +) +def test_primaite_session(temp_primaite_session): + """Test the training_config_main_rllib.yaml training config file.""" + with temp_primaite_session as session: + session_path = session.session_path + assert session_path.exists() + session.learn() + + assert len(session.learn_av_reward_per_episode_dict().keys()) == 10 + assert len(session.learn_all_transactions_dict().keys()) == 10 * 256 diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index 34cb43fb..5220fb1c 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
import pytest as pytest from primaite.config.lay_down_config import dos_very_basic_config_path @@ -28,7 +29,7 @@ def test_seeded_learning(temp_primaite_session): "Expected output is based upon a agent that was trained with " "seed 67890" ) session.learn() - actual_mean_reward_per_episode = session.learn_av_reward_per_episode() + actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict() assert actual_mean_reward_per_episode == expected_mean_reward_per_episode @@ -45,5 +46,5 @@ def test_deterministic_evaluation(temp_primaite_session): # do stuff session.learn() session.evaluate() - eval_mean_reward = session.eval_av_reward_per_episode_csv() + eval_mean_reward = session.eval_av_reward_per_episode_dict() assert len(set(eval_mean_reward.values())) == 1 diff --git a/tests/test_service_node.py b/tests/test_service_node.py index 4383fc1b..2f504cd6 100644 --- a/tests/test_service_node.py +++ b/tests/test_service_node.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Used to test Service Node functions.""" import pytest diff --git a/tests/test_session_loading.py b/tests/test_session_loading.py new file mode 100644 index 00000000..bcd28d96 --- /dev/null +++ b/tests/test_session_loading.py @@ -0,0 +1,189 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. 
+import os.path +import shutil +import tempfile +from pathlib import Path +from typing import Union +from uuid import uuid4 + +from primaite import getLogger +from primaite.agents.sb3 import SB3Agent +from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.main import run +from primaite.primaite_session import PrimaiteSession +from primaite.utils.session_output_reader import av_rewards_dict +from tests import TEST_ASSETS_ROOT + +_LOGGER = getLogger(__name__) + + +def copy_session_asset(asset_path: Union[str, Path]) -> str: + """Copies the asset into a temporary test folder.""" + if asset_path is None: + raise Exception("No path provided") + + if isinstance(asset_path, Path): + asset_path = str(os.path.normpath(asset_path)) + + copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4())) + + # copy the asset into a temp path + try: + shutil.copytree(asset_path, copy_path) + except Exception as e: + msg = f"Unable to copy directory: {asset_path}" + _LOGGER.error(msg, e) + print(msg, e) + + _LOGGER.debug(f"Copied test asset to: {copy_path}") + + # return the copied assets path + return copy_path + + +def test_load_sb3_session(): + """Test that loading an SB3 agent works.""" + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: -0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + loaded_agent = SB3Agent(session_path=test_path) + + # loaded agent should have the same UUID as the previous agent + assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde" + assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name + assert loaded_agent._training_config.agent_identifier == 
AgentIdentifier.PPO.name + assert loaded_agent._training_config.deterministic + assert loaded_agent._training_config.seed == 12345 + assert str(loaded_agent.session_path) == str(test_path) + + # run another learn session + loaded_agent.learn() + + learn_mean_rewards = av_rewards_dict( + loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv" + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # run an evaluation + loaded_agent.evaluate() + + # load the evaluation average reward csv file + eval_mean_reward = av_rewards_dict( + loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv" + ) + + # the agent config ran the evaluation in deterministic mode, so should have the same reward value + assert len(set(eval_mean_reward.values())) == 1 + + # the evaluation should be the same as a previous run + assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + + # delete the test directory + shutil.rmtree(test_path) + + +def test_load_primaite_session(): + """Test that loading a Primaite session works.""" + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: -0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + # create loaded session + session = PrimaiteSession(session_path=test_path) + + # run setup on session + session.setup() + + # make sure that the session was loaded correctly + assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde" + assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name + assert 
session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name + assert session._agent_session._training_config.deterministic + assert session._agent_session._training_config.seed == 12345 + assert str(session._agent_session.session_path) == str(test_path) + + # run another learn session + session.learn() + + learn_mean_rewards = av_rewards_dict( + session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv" + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # run an evaluation + session.evaluate() + + # load the evaluation average reward csv file + eval_mean_reward = av_rewards_dict( + session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv" + ) + + # the agent config ran the evaluation in deterministic mode, so should have the same reward value + assert len(set(eval_mean_reward.values())) == 1 + + # the evaluation should be the same as a previous run + assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + + # delete the test directory + shutil.rmtree(test_path) + + +def test_run_loading(): + """Test loading session via main.run.""" + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: -0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + # create loaded session + run(session_path=test_path) + + learn_mean_rewards = av_rewards_dict( + next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None) + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # delete the 
test directory + shutil.rmtree(test_path) diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index bfcffd42..4f7af9a6 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import time import pytest diff --git a/tests/test_train_eval_episode_steps.py b/tests/test_train_eval_episode_steps.py index b839e630..4f7bed16 100644 --- a/tests/test_train_eval_episode_steps.py +++ b/tests/test_train_eval_episode_steps.py @@ -1,3 +1,4 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import pytest from primaite import getLogger diff --git a/tests/test_training_config.py b/tests/test_training_config.py index d7fe4e50..4123ee39 100644 --- a/tests/test_training_config.py +++ b/tests/test_training_config.py @@ -1,4 +1,4 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import yaml from primaite.config import training_config