diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml
index 066c66b2..0bb03594 100644
--- a/.azure/azure-ci-build-pipeline.yaml
+++ b/.azure/azure-ci-build-pipeline.yaml
@@ -86,5 +86,5 @@ stages:
displayName: 'Perform PrimAITE Setup'
- script: |
- pytest tests/
+ pytest -n 4
displayName: 'Run tests'
diff --git a/.gitignore b/.gitignore
index 5d6434f1..ef1050e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,9 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
+tests/assets/**/*.png
+tests/assets/**/tensorboard_logs/
+tests/assets/**/checkpoints/
# Translations
*.mo
diff --git a/docs/_templates/custom-class-template.rst b/docs/_templates/custom-class-template.rst
index 8a539bc9..acffdc4c 100644
--- a/docs/_templates/custom-class-template.rst
+++ b/docs/_templates/custom-class-template.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..
diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst
index e6ecabd1..8eebad3e 100644
--- a/docs/_templates/custom-module-template.rst
+++ b/docs/_templates/custom-module-template.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..
diff --git a/docs/api.rst b/docs/api.rst
index df2bc193..b24dafc3 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
..
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden
diff --git a/docs/conf.py b/docs/conf.py
index 51b745cf..8afc1246 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
diff --git a/docs/index.rst b/docs/index.rst
index fed65919..de5bed46 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
Welcome to PrimAITE's documentation
====================================
@@ -20,17 +24,24 @@ What is PrimAITE built with
* `OpenAI's Gym `_ is used as the basis for AI blue agent interaction with the PrimAITE environment
* `Networkx `_ is used as the underlying data structure used for the PrimAITE environment
* `Stable Baselines 3 `_ is used as a default source of RL algorithms (although PrimAITE is not limited to SB3 agents)
+* `Ray RLlib `_ is used as an additional source of RL algorithms
+* `Typer `_ is used for building CLIs (Command Line Interface applications)
+* `Jupyterlab `_ is used as an extensible environment for interactive and reproducible computing, based on the Jupyter Notebook Architecture
+* `Platformdirs `_ is used for finding the right location to store user data and configuration but varies per platform
+* `Plotly `_ is used for building high level charts
+
Where next?
------------
-The best place to start is :ref:`about`
+Head over to the :ref:`getting-started` page to install and setup PrimAITE!
.. toctree::
:maxdepth: 8
:caption: Contents:
:hidden:
+ source/getting_started
source/about
source/config
source/primaite_session
@@ -41,12 +52,14 @@ The best place to start is :ref:`about`
source/glossary
source/migration_1.2_-_2.0
+
+.. TODO: Add project links once public repo has been created
+
.. toctree::
:caption: Project Links:
:hidden:
-..
- #Code <>
- #Issues <>
- #Pull Requests <>
- #Discussions <>
+ Code
+ Issues
+ Pull Requests
+ Discussions
diff --git a/docs/source/about.rst b/docs/source/about.rst
index a7135fc0..2068472c 100644
--- a/docs/source/about.rst
+++ b/docs/source/about.rst
@@ -1,4 +1,8 @@
-.. _about:
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
+.. _about:
About PrimAITE
==============
diff --git a/docs/source/config.rst b/docs/source/config.rst
index af590a24..058565da 100644
--- a/docs/source/config.rst
+++ b/docs/source/config.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
.. _config:
The Config Files Explained
diff --git a/docs/source/custom_agent.rst b/docs/source/custom_agent.rst
index b4552d64..ba438305 100644
--- a/docs/source/custom_agent.rst
+++ b/docs/source/custom_agent.rst
@@ -1,4 +1,8 @@
-Custom Agents
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
+Custom Agents
=============
diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst
index bbca3fce..0d3f21c3 100644
--- a/docs/source/dependencies.rst
+++ b/docs/source/dependencies.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
.. role:: raw-html(raw)
:format: html
diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst
new file mode 100644
index 00000000..13c9d699
--- /dev/null
+++ b/docs/source/getting_started.rst
@@ -0,0 +1,155 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
+.. _getting-started:
+
+Getting Started
+===============
+
+**Getting Started with PrimAITE**
+
+Pre-Requisites
+
+In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.10 installed. If you don't already have it, this is how to install it:
+
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ sudo add-apt-repository ppa:deadsnakes/ppa
+ sudo apt install python3.10
+ sudo apt-get install python3-pip
+ sudo apt-get install python3-venv
+
+ .. code-tab:: text
+ :caption: Windows (Powershell)
+
+ - Manual install from: https://www.python.org/downloads/release/python-31011/
+
+**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
+
+Install PrimAITE
+****************
+
+1. Create a primaite directory in your home directory:
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ mkdir ~/primaite
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ mkdir ~\primaite
+
+2. Navigate to the primaite directory and create a new python virtual environment (venv)
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ cd ~/primaite
+ python3 -m venv .venv
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ cd ~\primaite
+ python3 -m venv .venv
+ attrib +h .venv /s /d # Hides the .venv directory
+
+3. Activate the venv
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ source .venv/bin/activate
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ .\.venv\Scripts\activate
+
+
+4. Install PrimAITE using pip from PyPi
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ pip install primaite
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ pip install primaite
+
+5. Perform the PrimAITE setup
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ primaite setup
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ primaite setup
+
+Clone & Install PrimAITE for Development
+****************************************
+
+To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
+of your choice:
+
+.. TODO:: Add repo path once we know what it is
+
+.. code-block:: bash
+
+ git clone
+ cd primaite
+
+Create and activate your Python virtual environment (venv)
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ python3 -m venv venv
+ source venv/bin/activate
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ python3 -m venv venv
+ .\venv\Scripts\activate
+
+Install PrimAITE with the dev extra
+
+.. tabs:: lang
+
+ .. code-tab:: bash
+ :caption: Unix
+
+ pip install -e .[dev]
+
+ .. code-tab:: powershell
+ :caption: Windows (Powershell)
+
+ pip install -e .[dev]
+
+
+To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).
diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst
index 58b4cd5e..3422d51e 100644
--- a/docs/source/glossary.rst
+++ b/docs/source/glossary.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
Glossary
=============
diff --git a/docs/source/migration_1.2_-_2.0.rst b/docs/source/migration_1.2_-_2.0.rst
index 2adf9656..b7c9996d 100644
--- a/docs/source/migration_1.2_-_2.0.rst
+++ b/docs/source/migration_1.2_-_2.0.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
v1.2 to v2.0 Migration guide
============================
diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst
index a393093c..b8895fc7 100644
--- a/docs/source/primaite_session.rst
+++ b/docs/source/primaite_session.rst
@@ -1,3 +1,7 @@
+.. only:: comment
+
+ Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+
.. _run a primaite session:
Run a PrimAITE Session
@@ -10,6 +14,7 @@ A PrimAITE session can be ran either with the ``primaite session`` command from
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
+
.. tabs::
.. code-tab:: bash
@@ -19,7 +24,7 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf
source ./.venv/bin/activate
primaite session ./config/my_training_config.yaml ./config/my_lay_down_config.yaml
- .. code-tab:: bash
+ .. code-tab:: powershell
:caption: Powershell CLI
cd ~\primaite
@@ -42,6 +47,37 @@ The sub-directory is formatted as such: ``~/primaite/sessions//)
+
+When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
Outputs
-------
diff --git a/pyproject.toml b/pyproject.toml
index dc04f609..fc0551c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,7 @@ dev = [
"pip-licenses==4.3.0",
"pre-commit==2.20.0",
"pytest==7.2.0",
+ "pytest-xdist==3.3.1",
"pytest-cov==4.0.0",
"pytest-flake8==1.1.1",
"setuptools==66",
diff --git a/setup.py b/setup.py
index 63e905c0..efaf24bf 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from setuptools import setup
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa
diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py
index dacd5c12..c348681d 100644
--- a/src/primaite/__init__.py
+++ b/src/primaite/__init__.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import logging
import logging.config
import sys
diff --git a/src/primaite/acl/__init__.py b/src/primaite/acl/__init__.py
index 2623efbc..c6fd79f2 100644
--- a/src/primaite/acl/__init__.py
+++ b/src/primaite/acl/__init__.py
@@ -1,2 +1,2 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Access Control List. Models firewall functionality."""
diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py
index d4d843e3..007f12a0 100644
--- a/src/primaite/acl/access_control_list.py
+++ b/src/primaite/acl/access_control_list.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict
diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py
index 69532376..830cfe35 100644
--- a/src/primaite/acl/acl_rule.py
+++ b/src/primaite/acl/acl_rule.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""
diff --git a/src/primaite/agents/__init__.py b/src/primaite/agents/__init__.py
index 89580145..d987b43f 100644
--- a/src/primaite/agents/__init__.py
+++ b/src/primaite/agents/__init__.py
@@ -1 +1,2 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Common interface between RL agents from different libraries and PrimAITE."""
diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent_abc.py
similarity index 62%
rename from src/primaite/agents/agent.py
rename to src/primaite/agents/agent_abc.py
index 90860f7d..9b0dd031 100644
--- a/src/primaite/agents/agent.py
+++ b/src/primaite/agents/agent_abc.py
@@ -1,28 +1,24 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
-import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, Final, TYPE_CHECKING, Union
+from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from uuid import uuid4
-import yaml
-
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
+from primaite.utils.session_metadata_parser import parse_session_metadata
if TYPE_CHECKING:
from logging import Logger
- import numpy as np
-
-
_LOGGER: "Logger" = getLogger(__name__)
@@ -53,38 +49,63 @@ class AgentSessionABC(ABC):
"""
@abstractmethod
- def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
+ def __init__(
+ self,
+ training_config_path: Optional[Union[str, Path]] = None,
+ lay_down_config_path: Optional[Union[str, Path]] = None,
+ session_path: Optional[Union[str, Path]] = None,
+ ) -> None:
"""
- Initialise an agent session from config files.
+ Initialise an agent session from config files, or load a previous session.
+
+ If training configuration and laydown configuration are provided with a session path,
+ the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
+ :param session_path: directory path of the session to load
"""
- if not isinstance(training_config_path, Path):
- training_config_path = Path(training_config_path)
- self._training_config_path: Final[Union[Path, str]] = training_config_path
- self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
-
- if not isinstance(lay_down_config_path, Path):
- lay_down_config_path = Path(lay_down_config_path)
- self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
- self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
- self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
-
+ # initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
- self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
- "The session timestamp"
- self.session_path = get_session_path(self.session_timestamp)
- "The Session path"
+
+ # convert session to path
+ if session_path is not None:
+ if not isinstance(session_path, Path):
+ session_path = Path(session_path)
+
+ # if a session path is provided, load it
+ if not session_path.exists():
+ raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
+
+ # load session
+ self.load(session_path)
+ else:
+ # set training config path
+ if not isinstance(training_config_path, Path):
+ training_config_path = Path(training_config_path)
+ self._training_config_path: Union[Path, str] = training_config_path
+ self._training_config: TrainingConfig = training_config.load(self._training_config_path)
+
+ if not isinstance(lay_down_config_path, Path):
+ lay_down_config_path = Path(lay_down_config_path)
+ self._lay_down_config_path: Union[Path, str] = lay_down_config_path
+ self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
+ self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
+
+ # set random UUID for session
+ self._uuid = str(uuid4())
+ "The session timestamp"
+ self.session_path = get_session_path(self.session_timestamp)
+ "The Session path"
@property
def timestamp_str(self) -> str:
@@ -233,51 +254,27 @@ class AgentSessionABC(ABC):
def _get_latest_checkpoint(self) -> None:
pass
- @classmethod
- @abstractmethod
- def load(cls, path: Union[str, Path]) -> AgentSessionABC:
+ def load(self, path: Union[str, Path]):
"""Load an agent from file."""
- if not isinstance(path, Path):
- path = Path(path)
+ md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
- if path.exists():
- # Unpack the session_metadata.json file
- md_file = path / "session_metadata.json"
- with open(md_file, "r") as file:
- md_dict = json.load(file)
+ # set training config path
+ self._training_config_path: Union[Path, str] = training_config_path
+ self._training_config: TrainingConfig = training_config.load(self._training_config_path)
+ self._lay_down_config_path: Union[Path, str] = laydown_config_path
+ self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
+ self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
- # Create a temp directory and dump the training and lay down
- # configs into it
- temp_dir = path / ".temp"
- temp_dir.mkdir(exist_ok=True)
+ # set random UUID for session
+ self._uuid = md_dict["uuid"]
- temp_tc = temp_dir / "tc.yaml"
- with open(temp_tc, "w") as file:
- yaml.dump(md_dict["env"]["training_config"], file)
-
- temp_ldc = temp_dir / "ldc.yaml"
- with open(temp_ldc, "w") as file:
- yaml.dump(md_dict["env"]["lay_down_config"], file)
-
- agent = cls(temp_tc, temp_ldc)
-
- agent.session_path = path
-
- return agent
-
- else:
- # Session path does not exist
- msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
- _LOGGER.error(msg)
- raise FileNotFoundError(msg)
+ # set the session path
+ self.session_path = path
+ "The Session path"
@property
def _saved_agent_path(self) -> Path:
- file_name = (
- f"{self._training_config.agent_framework}_"
- f"{self._training_config.agent_identifier}_"
- f"{self.timestamp_str}.zip"
- )
+ file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod
@@ -313,104 +310,3 @@ class AgentSessionABC(ABC):
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
-
-
-class HardCodedAgentSessionABC(AgentSessionABC):
- """
- An Agent Session ABC for evaluation deterministic agents.
-
- This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
- implemented.
- """
-
- def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
- """
- Initialise a hardcoded agent session.
-
- :param training_config_path: YAML file containing configurable items defined in
- `primaite.config.training_config.TrainingConfig`
- :type training_config_path: Union[path, str]
- :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
- :type lay_down_config_path: Union[path, str]
- """
- super().__init__(training_config_path, lay_down_config_path)
- self._setup()
-
- def _setup(self) -> None:
- self._env: Primaite = Primaite(
- training_config_path=self._training_config_path,
- lay_down_config_path=self._lay_down_config_path,
- session_path=self.session_path,
- timestamp_str=self.timestamp_str,
- )
- super()._setup()
- self._can_learn = False
- self._can_evaluate = True
-
- def _save_checkpoint(self) -> None:
- pass
-
- def _get_latest_checkpoint(self) -> None:
- pass
-
- def learn(
- self,
- **kwargs: Any,
- ) -> None:
- """
- Train the agent.
-
- :param kwargs: Any agent-specific key-word args to be passed.
- """
- _LOGGER.warning("Deterministic agents cannot learn")
-
- @abstractmethod
- def _calculate_action(self, obs: np.ndarray) -> None:
- pass
-
- def evaluate(
- self,
- **kwargs: Any,
- ) -> None:
- """
- Evaluate the agent.
-
- :param kwargs: Any agent-specific key-word args to be passed.
- """
- self._env.set_as_eval() # noqa
- self.is_eval = True
-
- time_steps = self._training_config.num_eval_steps
- episodes = self._training_config.num_eval_episodes
-
- obs = self._env.reset()
- for episode in range(episodes):
- # Reset env and collect initial observation
- for step in range(time_steps):
- # Calculate action
- action = self._calculate_action(obs)
-
- # Perform the step
- obs, reward, done, info = self._env.step(action)
-
- if done:
- break
-
- # Introduce a delay between steps
- time.sleep(self._training_config.time_delay / 1000)
- obs = self._env.reset()
- self._env.close()
- super().evaluate()
-
- @classmethod
- def load(cls) -> None:
- """Load an agent from file."""
- _LOGGER.warning("Deterministic agents cannot be loaded")
-
- def save(self) -> None:
- """Save the agent."""
- _LOGGER.warning("Deterministic agents cannot be saved")
-
- def export(self) -> None:
- """Export the agent to transportable file format."""
- _LOGGER.warning("Deterministic agents cannot be exported")
diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py
new file mode 100644
index 00000000..ec4b53e7
--- /dev/null
+++ b/src/primaite/agents/hardcoded_abc.py
@@ -0,0 +1,116 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+import time
+from abc import abstractmethod
+from pathlib import Path
+from typing import Optional, Union
+
+from primaite import getLogger
+from primaite.agents.agent_abc import AgentSessionABC
+from primaite.environment.primaite_env import Primaite
+
+_LOGGER = getLogger(__name__)
+
+
+class HardCodedAgentSessionABC(AgentSessionABC):
+ """
+ An Agent Session ABC for evaluation deterministic agents.
+
+ This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
+ implemented.
+ """
+
+ def __init__(
+ self,
+ training_config_path: Optional[Union[str, Path]] = "",
+ lay_down_config_path: Optional[Union[str, Path]] = "",
+ session_path: Optional[Union[str, Path]] = None,
+ ):
+ """
+ Initialise a hardcoded agent session.
+
+ :param training_config_path: YAML file containing configurable items defined in
+ `primaite.config.training_config.TrainingConfig`
+ :type training_config_path: Union[path, str]
+ :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
+ :type lay_down_config_path: Union[path, str]
+ """
+ super().__init__(training_config_path, lay_down_config_path, session_path)
+ self._setup()
+
+ def _setup(self):
+ self._env: Primaite = Primaite(
+ training_config_path=self._training_config_path,
+ lay_down_config_path=self._lay_down_config_path,
+ session_path=self.session_path,
+ timestamp_str=self.timestamp_str,
+ )
+ super()._setup()
+ self._can_learn = False
+ self._can_evaluate = True
+
+ def _save_checkpoint(self):
+ pass
+
+ def _get_latest_checkpoint(self):
+ pass
+
+ def learn(
+ self,
+ **kwargs,
+ ):
+ """
+ Train the agent.
+
+ :param kwargs: Any agent-specific key-word args to be passed.
+ """
+ _LOGGER.warning("Deterministic agents cannot learn")
+
+ @abstractmethod
+ def _calculate_action(self, obs):
+ pass
+
+ def evaluate(
+ self,
+ **kwargs,
+ ):
+ """
+ Evaluate the agent.
+
+ :param kwargs: Any agent-specific key-word args to be passed.
+ """
+ self._env.set_as_eval() # noqa
+ self.is_eval = True
+
+ time_steps = self._training_config.num_eval_steps
+ episodes = self._training_config.num_eval_episodes
+
+ obs = self._env.reset()
+ for episode in range(episodes):
+ # Reset env and collect initial observation
+ for step in range(time_steps):
+ # Calculate action
+ action = self._calculate_action(obs)
+
+ # Perform the step
+ obs, reward, done, info = self._env.step(action)
+
+ if done:
+ break
+
+ # Introduce a delay between steps
+ time.sleep(self._training_config.time_delay / 1000)
+ obs = self._env.reset()
+ self._env.close()
+
+ @classmethod
+ def load(cls, path=None):
+ """Load an agent from file."""
+ _LOGGER.warning("Deterministic agents cannot be loaded")
+
+ def save(self):
+ """Save the agent."""
+ _LOGGER.warning("Deterministic agents cannot be saved")
+
+ def export(self):
+ """Export the agent to transportable file format."""
+ _LOGGER.warning("Deterministic agents cannot be exported")
diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py
index 0ac5022c..b8c49c14 100644
--- a/src/primaite/agents/hardcoded_acl.py
+++ b/src/primaite/agents/hardcoded_acl.py
@@ -1,10 +1,11 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
-from primaite.agents.agent import HardCodedAgentSessionABC
+from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,
diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py
index b74c3a0b..10cc2b72 100644
--- a/src/primaite/agents/hardcoded_node.py
+++ b/src/primaite/agents/hardcoded_node.py
@@ -1,6 +1,7 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import numpy as np
-from primaite.agents.agent import HardCodedAgentSessionABC
+from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py
index 0281de7e..8afc98a1 100644
--- a/src/primaite/agents/rllib.py
+++ b/src/primaite/agents/rllib.py
@@ -1,10 +1,11 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
-from typing import Any, Callable, Dict, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -14,7 +15,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
-from primaite.agents.agent import AgentSessionABC
+from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -48,7 +49,12 @@ def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
- def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
+ def __init__(
+ self,
+ training_config_path: Optional[Union[str, Path]] = "",
+ lay_down_config_path: Optional[Union[str, Path]] = "",
+ session_path: Optional[Union[str, Path]] = None,
+ ) -> None:
"""
Initialise the RLLib Agent training session.
@@ -61,6 +67,13 @@ class RLlibAgent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
+ # TODO: implement RLlib agent loading
+ if session_path is not None:
+ msg = "RLlib agent loading has not been implemented yet"
+ _LOGGER.error(msg)
+ print(msg)
+ raise NotImplementedError
+
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py
index 462360a0..881426ab 100644
--- a/src/primaite/agents/sb3.py
+++ b/src/primaite/agents/sb3.py
@@ -1,14 +1,16 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
+import json
from pathlib import Path
-from typing import Any, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
-from primaite.agents.agent import AgentSessionABC
+from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -21,7 +23,12 @@ _LOGGER: "Logger" = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
- def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
+ def __init__(
+ self,
+ training_config_path: Optional[Union[str, Path]] = None,
+ lay_down_config_path: Optional[Union[str, Path]] = None,
+ session_path: Optional[Union[str, Path]] = None,
+ ) -> None:
"""
Initialise the SB3 Agent training session.
@@ -34,7 +41,7 @@ class SB3Agent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
- super().__init__(training_config_path, lay_down_config_path)
+ super().__init__(training_config_path, lay_down_config_path, session_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -51,7 +58,7 @@ class SB3Agent(AgentSessionABC):
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
- self._setup()
+
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
@@ -61,8 +68,10 @@ class SB3Agent(AgentSessionABC):
self.is_eval = False
+ self._setup()
+
def _setup(self) -> None:
- super()._setup()
+ """Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -70,14 +79,43 @@ class SB3Agent(AgentSessionABC):
timestamp_str=self.timestamp_str,
)
- self._agent = self._agent_class(
- PPOMlp,
- self._env,
- verbose=self.sb3_output_verbose_level,
- n_steps=self._training_config.num_train_steps,
- tensorboard_log=str(self._tensorboard_log_path),
- seed=self._training_config.seed,
- )
+ # check if there is a zip file that needs to be loaded
+ load_file = next(self.session_path.rglob("*.zip"), None)
+
+ if not load_file:
+ # create a new env and agent
+
+ self._agent = self._agent_class(
+ PPOMlp,
+ self._env,
+ verbose=self.sb3_output_verbose_level,
+ n_steps=self._training_config.num_train_steps,
+ tensorboard_log=str(self._tensorboard_log_path),
+ seed=self._training_config.seed,
+ )
+ else:
+ # set env values from session metadata
+ with open(self.session_path / "session_metadata.json", "r") as file:
+ md_dict = json.load(file)
+
+ # load environment values
+ if self.is_eval:
+ # evaluation always starts at 0
+ self._env.episode_count = 0
+ self._env.total_step_count = 0
+ else:
+ # carry on from previous learning sessions
+ self._env.episode_count = md_dict["learning"]["total_episodes"]
+ self._env.total_step_count = md_dict["learning"]["total_time_steps"]
+
+ # load the file
+ self._agent = self._agent_class.load(load_file, env=self._env)
+
+ # set agent values
+ self._agent.verbose = self.sb3_output_verbose_level
+ self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
+
+ super()._setup()
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -149,11 +187,6 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
- @classmethod
- def load(cls, path: Union[str, Path]) -> SB3Agent:
- """Load an agent from file."""
- raise NotImplementedError
-
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)
diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py
index 2c130c0c..bfc7bcf2 100644
--- a/src/primaite/agents/simple.py
+++ b/src/primaite/agents/simple.py
@@ -1,6 +1,7 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
-from primaite.agents.agent import HardCodedAgentSessionABC
+from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
if TYPE_CHECKING:
diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py
index 353978f1..ff0ca8d2 100644
--- a/src/primaite/agents/utils.py
+++ b/src/primaite/agents/utils.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np
diff --git a/src/primaite/cli.py b/src/primaite/cli.py
index 863cbfd2..14db236c 100644
--- a/src/primaite/cli.py
+++ b/src/primaite/cli.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Provides a CLI using Typer as an entry point."""
import logging
import os
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True) -> None:
@app.command()
-def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
+def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
"""
Run a PrimAITE session.
@@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
+
+ load: The directory of a previous session. Optional. If no value is passed, then the session
+ will use the default training config and laydown config. Inversely, if a training config and laydown config
+ is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
+ and laydown config.
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
+ if load is not None:
+ run(session_path=load)
+
if not tc:
tc = main_training_config_path()
diff --git a/src/primaite/common/__init__.py b/src/primaite/common/__init__.py
index 5f47b0b5..738a30d1 100644
--- a/src/primaite/common/__init__.py
+++ b/src/primaite/common/__init__.py
@@ -1,2 +1,2 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Objects which are shared between many PrimAITE modules."""
diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py
index ff090ca9..0209c64d 100644
--- a/src/primaite/common/enums.py
+++ b/src/primaite/common/enums.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""
from enum import Enum, IntEnum
diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py
index f7a757e8..048ed0ab 100644
--- a/src/primaite/common/protocol.py
+++ b/src/primaite/common/protocol.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The protocol class."""
diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py
index 1351a30d..7ee694db 100644
--- a/src/primaite/common/service.py
+++ b/src/primaite/common/service.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Service class."""
from primaite.common.enums import SoftwareState
diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py
index 03ed4cf1..9bd899f7 100644
--- a/src/primaite/config/__init__.py
+++ b/src/primaite/config/__init__.py
@@ -1 +1,2 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Configuration parameters for running experiments."""
diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py
index 2cc5f9c2..80b0f619 100644
--- a/src/primaite/config/lay_down_config.py
+++ b/src/primaite/config/lay_down_config.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Any, Dict, Final, TYPE_CHECKING, Union
diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py
index 628e2818..f618b37c 100644
--- a/src/primaite/config/training_config.py
+++ b/src/primaite/config/training_config.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field
diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py
index db6ce6c8..ad43c141 100644
--- a/src/primaite/data_viz/__init__.py
+++ b/src/primaite/data_viz/__init__.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utility to generate plots of sessions metrics after PrimAITE."""
from enum import Enum
diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py
index 245b9774..39c2b4cc 100644
--- a/src/primaite/data_viz/session_plots.py
+++ b/src/primaite/data_viz/session_plots.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Dict, Optional, Union
diff --git a/src/primaite/environment/__init__.py b/src/primaite/environment/__init__.py
index 8b0060c0..e837fe1e 100644
--- a/src/primaite/environment/__init__.py
+++ b/src/primaite/environment/__init__.py
@@ -1,2 +1,2 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""
diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py
index cb9872d1..ebc47043 100644
--- a/src/primaite/environment/observations.py
+++ b/src/primaite/environment/observations.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py
index f78b5f8d..8f34204b 100644
--- a/src/primaite/environment/primaite_env.py
+++ b/src/primaite/environment/primaite_env.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import logging
diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py
index a0efac4d..aad15246 100644
--- a/src/primaite/environment/reward.py
+++ b/src/primaite/environment/reward.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict, TYPE_CHECKING, Union
diff --git a/src/primaite/links/__init__.py b/src/primaite/links/__init__.py
index 6257f282..21ce44ba 100644
--- a/src/primaite/links/__init__.py
+++ b/src/primaite/links/__init__.py
@@ -1,2 +1,2 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Network connections between nodes in the simulation."""
diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py
index 145de5f3..aa3fa7fb 100644
--- a/src/primaite/links/link.py
+++ b/src/primaite/links/link.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The link class."""
from typing import List
diff --git a/src/primaite/main.py b/src/primaite/main.py
index 78420972..aed39d73 100644
--- a/src/primaite/main.py
+++ b/src/primaite/main.py
@@ -1,8 +1,8 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
-from typing import Union
+from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
def run(
- training_config_path: Union[str, Path],
- lay_down_config_path: Union[str, Path],
+ training_config_path: Optional[Union[str, Path]] = "",
+ lay_down_config_path: Optional[Union[str, Path]] = "",
+ session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
- :param training_config_path: The training config filepath.
- :param lay_down_config_path: The lay down config filepath.
+ :param training_config_path: YAML file containing configurable items defined in
+ `primaite.config.training_config.TrainingConfig`
+ :type training_config_path: Union[path, str]
+ :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
+ :type lay_down_config_path: Union[path, str]
+ :param session_path: directory path of the session to load
"""
- session = PrimaiteSession(training_config_path, lay_down_config_path)
+ session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
session.setup()
session.learn()
@@ -31,9 +36,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
+ parser.add_argument("--load")
+
args = parser.parse_args()
- if not args.tc:
- _LOGGER.error("Please provide a training config file using the --tc " "argument")
- if not args.ldc:
- _LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
- run(training_config_path=args.tc, lay_down_config_path=args.ldc)
+ if args.load:
+ run(session_path=args.load)
+ else:
+ if not args.tc:
+ _LOGGER.error("Please provide a training config file using the --tc " "argument")
+ if not args.ldc:
+ _LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
+ run(training_config_path=args.tc, lay_down_config_path=args.ldc)
diff --git a/src/primaite/nodes/__init__.py b/src/primaite/nodes/__init__.py
index 19347372..43b213d6 100644
--- a/src/primaite/nodes/__init__.py
+++ b/src/primaite/nodes/__init__.py
@@ -1,2 +1,2 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Nodes represent network hosts in the simulation."""
diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py
index b73f80f0..b5df70b5 100644
--- a/src/primaite/nodes/active_node.py
+++ b/src/primaite/nodes/active_node.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final
diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py
index 7dd7d962..9118fa46 100644
--- a/src/primaite/nodes/node.py
+++ b/src/primaite/nodes/node.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The base Node class."""
from typing import Final
diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py
index 0826efe6..8e03b40f 100644
--- a/src/primaite/nodes/node_state_instruction_green.py
+++ b/src/primaite/nodes/node_state_instruction_green.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py
index abbe07ad..786e93ac 100644
--- a/src/primaite/nodes/node_state_instruction_red.py
+++ b/src/primaite/nodes/node_state_instruction_red.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py
index c79636e3..88c8cc85 100644
--- a/src/primaite/nodes/passive_node.py
+++ b/src/primaite/nodes/passive_node.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py
index ef0cd92e..ce1ffe92 100644
--- a/src/primaite/nodes/service_node.py
+++ b/src/primaite/nodes/service_node.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final
diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py
index 6bb5abf4..eaf10005 100644
--- a/src/primaite/notebooks/__init__.py
+++ b/src/primaite/notebooks/__init__.py
@@ -1,5 +1,6 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+
import importlib.util
import os
import subprocess
diff --git a/src/primaite/pol/__init__.py b/src/primaite/pol/__init__.py
index cba4b28b..1adb1491 100644
--- a/src/primaite/pol/__init__.py
+++ b/src/primaite/pol/__init__.py
@@ -1,2 +1,2 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Pattern of Life- Represents the actions of users on the network."""
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py
index 7df87590..0425a831 100644
--- a/src/primaite/pol/green_pol.py
+++ b/src/primaite/pol/green_pol.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict
diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py
index b46dbf22..7fab340d 100644
--- a/src/primaite/pol/ier.py
+++ b/src/primaite/pol/ier.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""
Information Exchange Requirements for APE.
diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py
index 2801e8b0..ad55fa24 100644
--- a/src/primaite/pol/red_agent_pol.py
+++ b/src/primaite/pol/red_agent_pol.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict
diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py
index 5ef856d7..ab3c2150 100644
--- a/src/primaite/primaite_session.py
+++ b/src/primaite/primaite_session.py
@@ -1,11 +1,12 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations
from pathlib import Path
-from typing import Any, Dict, Final, Union
+from typing import Any, Dict, Final, Optional, Union
from primaite import getLogger
-from primaite.agents.agent import AgentSessionABC
+from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
@@ -14,6 +15,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
+from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
@@ -27,15 +29,39 @@ class PrimaiteSession:
def __init__(
self,
- training_config_path: Union[str, Path],
- lay_down_config_path: Union[str, Path],
+ training_config_path: Optional[Union[str, Path]] = "",
+ lay_down_config_path: Optional[Union[str, Path]] = "",
+ session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
The PrimaiteSession constructor.
- :param training_config_path: The training config path.
- :param lay_down_config_path: The lay down config path.
+ :param training_config_path: YAML file containing configurable items defined in
+ `primaite.config.training_config.TrainingConfig`
+ :type training_config_path: Union[path, str]
+ :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
+ :type lay_down_config_path: Union[path, str]
+ :param session_path: directory path of the session to load
"""
+ self._agent_session: AgentSessionABC = None # noqa
+ self.session_path: Path = session_path # noqa
+ self.timestamp_str: str = None # noqa
+ self.learning_path: Path = None # noqa
+ self.evaluation_path: Path = None # noqa
+
+ # check if session path is provided
+ if session_path is not None:
+ # set load_session to true
+ self.is_load_session = True
+ if not isinstance(session_path, Path):
+ session_path = Path(session_path)
+
+ # if a session path is provided, load it
+ if not session_path.exists():
+ raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
+
+ md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
+
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
@@ -60,11 +86,15 @@ class PrimaiteSession:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
- self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = HardCodedNodeAgent(
+ self._training_config_path, self._lay_down_config_path, self.session_path
+ )
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
- self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = HardCodedACLAgent(
+ self._training_config_path, self._lay_down_config_path, self.session_path
+ )
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -77,11 +107,15 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
if self._training_config.action_type == ActionType.NODE:
- self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = DoNothingNodeAgent(
+ self._training_config_path, self._lay_down_config_path, self.session_path
+ )
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
- self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = DoNothingACLAgent(
+ self._training_config_path, self._lay_down_config_path, self.session_path
+ )
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -93,10 +127,14 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
- self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = RandomAgent(
+ self._training_config_path, self._lay_down_config_path, self.session_path
+ )
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
- self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = DummyAgent(
+ self._training_config_path, self._lay_down_config_path, self.session_path
+ )
else:
# Invalid AgentFramework AgentIdentifier combo
@@ -105,12 +143,12 @@ class PrimaiteSession:
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
- self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
- self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
+ self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
else:
# Invalid AgentFramework
diff --git a/src/primaite/setup/__init__.py b/src/primaite/setup/__init__.py
index 3c0bfe14..acfa48c4 100644
--- a/src/primaite/setup/__init__.py
+++ b/src/primaite/setup/__init__.py
@@ -1,2 +1,2 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities to prepare the user's data folders."""
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
diff --git a/src/primaite/setup/old_installation_clean_up.py b/src/primaite/setup/old_installation_clean_up.py
index 1603f06e..43950e4f 100644
--- a/src/primaite/setup/old_installation_clean_up.py
+++ b/src/primaite/setup/old_installation_clean_up.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import getLogger
diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py
index 530a2c30..775f43b5 100644
--- a/src/primaite/setup/reset_demo_notebooks.py
+++ b/src/primaite/setup/reset_demo_notebooks.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil
diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py
index 99d04149..df3b36a1 100644
--- a/src/primaite/setup/reset_example_configs.py
+++ b/src/primaite/setup/reset_example_configs.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil
diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py
index 1288e63c..56f16a08 100644
--- a/src/primaite/setup/setup_app_dirs.py
+++ b/src/primaite/setup/setup_app_dirs.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
diff --git a/src/primaite/transactions/__init__.py b/src/primaite/transactions/__init__.py
index 45315b22..9a881fd5 100644
--- a/src/primaite/transactions/__init__.py
+++ b/src/primaite/transactions/__init__.py
@@ -1,2 +1,2 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Record data of the system's state and agent's observations and actions."""
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py
index 09ec2cec..1a702748 100644
--- a/src/primaite/transactions/transaction.py
+++ b/src/primaite/transactions/transaction.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
diff --git a/src/primaite/utils/__init__.py b/src/primaite/utils/__init__.py
index 55e8a6ba..5dbd1e5f 100644
--- a/src/primaite/utils/__init__.py
+++ b/src/primaite/utils/__init__.py
@@ -1 +1,2 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities for PrimAITE."""
diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py
index a994f880..b9abca8f 100644
--- a/src/primaite/utils/package_data.py
+++ b/src/primaite/utils/package_data.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
from pathlib import Path
from typing import TYPE_CHECKING
diff --git a/src/primaite/utils/session_metadata_parser.py b/src/primaite/utils/session_metadata_parser.py
new file mode 100644
index 00000000..eb3c3339
--- /dev/null
+++ b/src/primaite/utils/session_metadata_parser.py
@@ -0,0 +1,59 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+import json
+from pathlib import Path
+from typing import Union
+
+import yaml
+
+from primaite import getLogger
+
+_LOGGER = getLogger(__name__)
+
+
+def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
+ """
+ Loads a session metadata from the given directory path.
+
+ :param session_path: Directory where the session metadata file is in
+ :param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
+
+ :return: Dictionary which has all the session metadata contents
+ :rtype: Dict
+
+ :return: Path where the YAML copy of the training config is dumped into
+ :rtype: str
+ :return: Path where the YAML copy of the laydown config is dumped into
+ :rtype: str
+ """
+ if not isinstance(session_path, Path):
+ session_path = Path(session_path)
+
+ if not session_path.exists():
+ # Session path does not exist
+ msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
+ _LOGGER.error(msg)
+ raise FileNotFoundError(msg)
+
+ # Unpack the session_metadata.json file
+ md_file = session_path / "session_metadata.json"
+ with open(md_file, "r") as file:
+ md_dict = json.load(file)
+
+ # if dict only, return dict without doing anything else
+ if dict_only:
+ return md_dict
+
+ # Create a temp directory and dump the training and lay down
+ # configs into it
+ temp_dir = session_path / ".temp"
+ temp_dir.mkdir(exist_ok=True)
+
+ temp_tc = temp_dir / "tc.yaml"
+ with open(temp_tc, "w") as file:
+ yaml.dump(md_dict["env"]["training_config"], file)
+
+ temp_ldc = temp_dir / "ldc.yaml"
+ with open(temp_ldc, "w") as file:
+ yaml.dump(md_dict["env"]["lay_down_config"], file)
+
+ return [md_dict, temp_tc, temp_ldc]
diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py
index ad3dd4f4..7089c69a 100644
--- a/src/primaite/utils/session_output_reader.py
+++ b/src/primaite/utils/session_output_reader.py
@@ -1,5 +1,6 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
-from typing import Dict, Union
+from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
@@ -13,8 +14,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
- :return: The average rewards per episode cdv as a dict.
+ :return: The average rewards per episode csv as a dict.
"""
- df = pl.read_csv(av_rewards_csv_file).to_dict()
+ df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
- return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
+ return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
+
+
+def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
+ """
+ Read an all transactions csv file and return as a dict.
+
+ The dict keys are a tuple with the structure (episode, step). The dict
+ values are the remaining columns as a dict.
+
+ :param all_transactions_csv_file: The all transactions csv file path.
+ :return: The all transactions csv file as a dict.
+ """
+ df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
+ new_dict = {}
+
+ episodes = df_dict["Episode"]
+ steps = df_dict["Step"]
+ keys = list(df_dict.keys())
+
+ for i in range(len(episodes)):
+ key = (episodes[i], steps[i])
+ value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
+ new_dict[key] = value_dict
+
+ return new_dict
diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py
index d05f69b1..e7f1b248 100644
--- a/src/primaite/utils/session_output_writer.py
+++ b/src/primaite/utils/session_output_writer.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union
diff --git a/tests/__init__.py b/tests/__init__.py
index 4a0bdce1..f8e6fc55 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,6 +1,9 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
"The tests config root directory."
+
+TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
+"The tests assets root directory."
diff --git a/tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip b/tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip
new file mode 100644
index 00000000..666151e7
Binary files /dev/null and b/tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip differ
diff --git a/tests/assets/example_sb3_agent_session/session_metadata.json b/tests/assets/example_sb3_agent_session/session_metadata.json
new file mode 100644
index 00000000..20f6a77c
--- /dev/null
+++ b/tests/assets/example_sb3_agent_session/session_metadata.json
@@ -0,0 +1 @@
+{"uuid": "301874d3-2e14-43c2-ba7f-e2b03ad05dde", "start_datetime": "2023-07-14T09:48:22.973005", "end_datetime": "2023-07-14T09:48:34.182715", "learning": {"total_episodes": 10, "total_time_steps": 2560}, "evaluation": {"total_episodes": 5, "total_time_steps": 1280}, "env": {"training_config": {"agent_framework": "SB3", "deep_learning_framework": "TF2", "agent_identifier": "PPO", "hard_coded_agent_view": "FULL", "random_red_agent": false, "action_type": "NODE", "num_train_episodes": 10, "num_train_steps": 256, "num_eval_episodes": 5, "num_eval_steps": 256, "checkpoint_every_n_episodes": 10, "observation_space": {"components": [{"name": "NODE_LINK_TABLE"}]}, "time_delay": 5, "session_type": "TRAIN_EVAL", "load_agent": false, "agent_load_file": null, "observation_space_high_value": 1000000000, "sb3_output_verbose_level": "NONE", "all_ok": 0, "off_should_be_on": -0.001, "off_should_be_resetting": -0.0005, "on_should_be_off": -0.0002, "on_should_be_resetting": -0.0005, "resetting_should_be_on": -0.0005, "resetting_should_be_off": -0.0002, "resetting": -0.0003, "good_should_be_patching": 0.0002, "good_should_be_compromised": 0.0005, "good_should_be_overwhelmed": 0.0005, "patching_should_be_good": -0.0005, "patching_should_be_compromised": 0.0002, "patching_should_be_overwhelmed": 0.0002, "patching": -0.0003, "compromised_should_be_good": -0.002, "compromised_should_be_patching": -0.002, "compromised_should_be_overwhelmed": -0.002, "compromised": -0.002, "overwhelmed_should_be_good": -0.002, "overwhelmed_should_be_patching": -0.002, "overwhelmed_should_be_compromised": -0.002, "overwhelmed": -0.002, "good_should_be_repairing": 0.0002, "good_should_be_restoring": 0.0002, "good_should_be_corrupt": 0.0005, "good_should_be_destroyed": 0.001, "repairing_should_be_good": -0.0005, "repairing_should_be_restoring": 0.0002, "repairing_should_be_corrupt": 0.0002, "repairing_should_be_destroyed": 0.0, "repairing": -0.0003, "restoring_should_be_good": -0.001, "restoring_should_be_repairing": -0.0002, "restoring_should_be_corrupt": 0.0001, "restoring_should_be_destroyed": 0.0002, "restoring": -0.0006, "corrupt_should_be_good": -0.001, "corrupt_should_be_repairing": -0.001, "corrupt_should_be_restoring": -0.001, "corrupt_should_be_destroyed": 0.0002, "corrupt": -0.001, "destroyed_should_be_good": -0.002, "destroyed_should_be_repairing": -0.002, "destroyed_should_be_restoring": -0.002, "destroyed_should_be_corrupt": -0.002, "destroyed": -0.002, "scanning": -0.0002, "red_ier_running": -0.0005, "green_ier_blocked": -0.001, "os_patching_duration": 5, "node_reset_duration": 5, "node_booting_duration": 3, "node_shutdown_duration": 2, "service_patching_duration": 5, "file_system_repairing_limit": 5, "file_system_restoring_limit": 5, "file_system_scanning_limit": 5, "deterministic": true, "seed": 12345}, "lay_down_config": [{"item_type": "PORTS", "ports_list": [{"port": "80"}]}, {"item_type": "SERVICES", "service_list": [{"name": "TCP"}]}, {"item_type": "NODE", "node_id": "1", "name": "PC1", "node_class": "SERVICE", "node_type": "COMPUTER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.2", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "NODE", "node_id": "2", "name": "PC2", "node_class": "SERVICE", "node_type": "COMPUTER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.3", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "NODE", "node_id": "3", "name": "SWITCH1", "node_class": "ACTIVE", "node_type": "SWITCH", "priority": "P2", "hardware_state": "ON", "ip_address": "192.168.1.1", "software_state": "GOOD", "file_system_state": "GOOD"}, {"item_type": "NODE", "node_id": "4", "name": "SERVER1", "node_class": "SERVICE", "node_type": "SERVER", "priority": "P5", "hardware_state": "ON", "ip_address": "192.168.1.4", "software_state": "GOOD", "file_system_state": "GOOD", "services": [{"name": "TCP", "port": "80", "state": "GOOD"}]}, {"item_type": "LINK", "id": "5", "name": "link1", "bandwidth": 1000000000, "source": "1", "destination": "3"}, {"item_type": "LINK", "id": "6", "name": "link2", "bandwidth": 1000000000, "source": "2", "destination": "3"}, {"item_type": "LINK", "id": "7", "name": "link3", "bandwidth": 1000000000, "source": "3", "destination": "4"}, {"item_type": "GREEN_IER", "id": "8", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "1", "destination": "4", "mission_criticality": 1}, {"item_type": "GREEN_IER", "id": "9", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "2", "destination": "4", "mission_criticality": 1}, {"item_type": "GREEN_IER", "id": "10", "start_step": 1, "end_step": 256, "load": 10000, "protocol": "TCP", "port": "80", "source": "4", "destination": "2", "mission_criticality": 5}, {"item_type": "ACL_RULE", "id": "11", "permission": "ALLOW", "source": "192.168.1.2", "destination": "192.168.1.4", "protocol": "TCP", "port": 80}, {"item_type": "ACL_RULE", "id": "12", "permission": "ALLOW", "source": "192.168.1.3", "destination": "192.168.1.4", "protocol": "TCP", "port": 80}, {"item_type": "ACL_RULE", "id": "13", "permission": "ALLOW", "source": "192.168.1.4", "destination": "192.168.1.3", "protocol": "TCP", "port": 80}, {"item_type": "RED_POL", "id": "14", "start_step": 20, "end_step": 20, "targetNodeId": "1", "initiator": "DIRECT", "type": "SERVICE", "protocol": "TCP", "state": "COMPROMISED", "sourceNodeId": "NA", "sourceNodeService": "NA", "sourceNodeServiceState": "NA"}, {"item_type": "RED_IER", "id": "15", "start_step": 30, "end_step": 256, "load": 10000000, "protocol": "TCP", "port": "80", "source": "1", "destination": "4", "mission_criticality": 0}, {"item_type": "RED_POL", "id": "16", "start_step": 40, "end_step": 40, "targetNodeId": "4", "initiator": "IER", "type": "SERVICE", "protocol": "TCP", "state": "OVERWHELMED", "sourceNodeId": "NA", "sourceNodeService": "NA", "sourceNodeServiceState": "NA"}]}}
diff --git a/tests/config/legacy_conversion/legacy_training_config.yaml b/tests/config/legacy_conversion/legacy_training_config.yaml
index 5c2025a2..fb24e3d7 100644
--- a/tests/config/legacy_conversion/legacy_training_config.yaml
+++ b/tests/config/legacy_conversion/legacy_training_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Main Config File
# Generic config values
diff --git a/tests/config/legacy_conversion/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml
index c57741f7..3df29d04 100644
--- a/tests/config/legacy_conversion/new_training_config.yaml
+++ b/tests/config/legacy_conversion/new_training_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Main Config File
# Generic config values
diff --git a/tests/config/obs_tests/laydown.yaml b/tests/config/obs_tests/laydown.yaml
index ef77ce83..3590492b 100644
--- a/tests/config/obs_tests/laydown.yaml
+++ b/tests/config/obs_tests/laydown.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '80'
diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml
index 2ac8f59a..8374115d 100644
--- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml
+++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml
index a9986d5b..c68199a0 100644
--- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml
+++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml
index a129712c..c662e715 100644
--- a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml
+++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml
index 03d11b82..bd23bded 100644
--- a/tests/config/obs_tests/main_config_without_obs.yaml
+++ b/tests/config/obs_tests/main_config_without_obs.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml
index aadbd449..65257d62 100644
--- a/tests/config/one_node_states_on_off_lay_down_config.yaml
+++ b/tests/config/one_node_states_on_off_lay_down_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'
diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml
index dd425a8c..133b2af8 100644
--- a/tests/config/one_node_states_on_off_main_config.yaml
+++ b/tests/config/one_node_states_on_off_main_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
@@ -7,6 +8,14 @@
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
+# Sets which deep learning framework will be used (by RLlib ONLY).
+# Default is TF (Tensorflow).
+# Options are:
+# "TF" (Tensorflow)
+# TF2 (Tensorflow 2.X)
+# TORCH (PyTorch)
+deep_learning_framework: TF2
+
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
@@ -17,32 +26,78 @@ agent_framework: CUSTOM
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
+# Sets whether Red Agent POL and IER is randomised.
+# Options are:
+# True
+# False
+random_red_agent: False
+
+# The (integer) seed to be used in random number generation
+# Default is None (null)
+seed: null
+
+# Set whether the agent will be deterministic instead of stochastic
+# Options are:
+# True
+# False
+deterministic: False
+
+# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
+# Options are:
+# "BASIC" (The current observation space only)
+# "FULL" (Full environment view with actions taken and reward feedback)
+hard_coded_agent_view: FULL
+
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
+# observation space
+observation_space:
+ # flatten: true
+ components:
+ - name: NODE_LINK_TABLE
+ # - name: NODE_STATUSES
+ # - name: LINK_TRAFFIC_LEVELS
+# Number of episodes for training to run per session
+num_train_episodes: 10
+
+# Number of time_steps for training per episode
+num_train_steps: 256
+
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 15
-# Time delay between steps (for generic agents)
-time_delay: 1
-# Type of session to be run (TRAINING or EVALUATION)
+# Sets how often the agent will save a checkpoint (every n time episodes).
+# Set to 0 if no checkpoints are required. Default is 10
+checkpoint_every_n_episodes: 10
+
+# Time delay (milliseconds) between steps for CUSTOM agents.
+time_delay: 5
+
+# Type of session to be run. Options are:
+# "TRAIN" (Trains an agent)
+# "EVAL" (Evaluates an agent)
+# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: EVAL
-# Determine whether to load an agent from file
-load_agent: False
-# File path and file name of agent if you're loading one in
-agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
+# The Stable Baselines3 learn/eval output verbosity level:
+# Options are:
+# "NONE" (No Output)
+# "INFO" (Info Messages (such as devices and wrappers used))
+# "DEBUG" (All Messages)
+sb3_output_verbose_level: NONE
+
# Reward values
# Generic
all_ok: 0
diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml
index 14b3f087..1b1d5deb 100644
--- a/tests/config/ppo_not_seeded_training_config.yaml
+++ b/tests/config/ppo_not_seeded_training_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/ppo_seeded_training_config.yaml b/tests/config/ppo_seeded_training_config.yaml
index a176c793..14a4face 100644
--- a/tests/config/ppo_seeded_training_config.yaml
+++ b/tests/config/ppo_seeded_training_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml
index 0f378634..2fcca1f2 100644
--- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml
+++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml
index 9d05b84a..9fb82ac2 100644
--- a/tests/config/single_action_space_lay_down_config.yaml
+++ b/tests/config/single_action_space_lay_down_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'
diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml
index c875757f..625491fe 100644
--- a/tests/config/single_action_space_main_config.yaml
+++ b/tests/config/single_action_space_main_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml
index e2b24b41..3416029c 100644
--- a/tests/config/test_random_red_main_config.yaml
+++ b/tests/config/test_random_red_main_config.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
@@ -5,7 +6,15 @@
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
-agent_framework: CUSTOM
+agent_framework: SB3
+
+# Sets which deep learning framework will be used (by RLlib ONLY).
+# Default is TF (Tensorflow).
+# Options are:
+# "TF" (Tensorflow)
+# TF2 (Tensorflow 2.X)
+# TORCH (PyTorch)
+deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
@@ -15,7 +24,7 @@ agent_framework: CUSTOM
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
-agent_identifier: DUMMY
+agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
@@ -23,92 +32,128 @@ agent_identifier: DUMMY
# False
random_red_agent: True
+# The (integer) seed to be used in random number generation
+# Default is None (null)
+seed: null
+
+# Set whether the agent will be deterministic instead of stochastic
+# Options are:
+# True
+# False
+deterministic: False
+
+# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
+# Options are:
+# "BASIC" (The current observation space only)
+# "FULL" (Full environment view with actions taken and reward feedback)
+hard_coded_agent_view: FULL
+
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
+# observation space
+observation_space:
+ # flatten: true
+ components:
+ - name: NODE_LINK_TABLE
+ # - name: NODE_STATUSES
+ # - name: LINK_TRAFFIC_LEVELS
+
+
# Number of episodes for training to run per session
-num_train_episodes: 2
+num_train_episodes: 10
# Number of time_steps for training per episode
-num_train_steps: 15
+num_train_steps: 256
# Number of episodes for evaluation to run per session
-num_eval_episodes: 2
+num_eval_episodes: 1
# Number of time_steps for evaluation per episode
-num_eval_steps: 15
-# Time delay between steps (for generic agents)
-time_delay: 1
+num_eval_steps: 256
-# Type of session to be run (TRAINING or EVALUATION)
-session_type: EVAL
-# Determine whether to load an agent from file
-load_agent: False
-# File path and file name of agent if you're loading one in
-agent_load_file: C:\[Path]\[agent_saved_filename.zip]
+# Sets how often the agent will save a checkpoint (every n time episodes).
+# Set to 0 if no checkpoints are required. Default is 10
+checkpoint_every_n_episodes: 10
+
+# Time delay (milliseconds) between steps for CUSTOM agents.
+time_delay: 5
+
+# Type of session to be run. Options are:
+# "TRAIN" (Trains an agent)
+# "EVAL" (Evaluates an agent)
+# "TRAIN_EVAL" (Trains then evaluates an agent)
+session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
+# The Stable Baselines3 learn/eval output verbosity level:
+# Options are:
+# "NONE" (No Output)
+# "INFO" (Info Messages (such as devices and wrappers used))
+# "DEBUG" (All Messages)
+sb3_output_verbose_level: NONE
+
# Reward values
# Generic
all_ok: 0
# Node Hardware State
-off_should_be_on: -10
-off_should_be_resetting: -5
-on_should_be_off: -2
-on_should_be_resetting: -5
-resetting_should_be_on: -5
-resetting_should_be_off: -2
-resetting: -3
+off_should_be_on: -0.001
+off_should_be_resetting: -0.0005
+on_should_be_off: -0.0002
+on_should_be_resetting: -0.0005
+resetting_should_be_on: -0.0005
+resetting_should_be_off: -0.0002
+resetting: -0.0003
# Node Software or Service State
-good_should_be_patching: 2
-good_should_be_compromised: 5
-good_should_be_overwhelmed: 5
-patching_should_be_good: -5
-patching_should_be_compromised: 2
-patching_should_be_overwhelmed: 2
-patching: -3
-compromised_should_be_good: -20
-compromised_should_be_patching: -20
-compromised_should_be_overwhelmed: -20
-compromised: -20
-overwhelmed_should_be_good: -20
-overwhelmed_should_be_patching: -20
-overwhelmed_should_be_compromised: -20
-overwhelmed: -20
+good_should_be_patching: 0.0002
+good_should_be_compromised: 0.0005
+good_should_be_overwhelmed: 0.0005
+patching_should_be_good: -0.0005
+patching_should_be_compromised: 0.0002
+patching_should_be_overwhelmed: 0.0002
+patching: -0.0003
+compromised_should_be_good: -0.002
+compromised_should_be_patching: -0.002
+compromised_should_be_overwhelmed: -0.002
+compromised: -0.002
+overwhelmed_should_be_good: -0.002
+overwhelmed_should_be_patching: -0.002
+overwhelmed_should_be_compromised: -0.002
+overwhelmed: -0.002
# Node File System State
-good_should_be_repairing: 2
-good_should_be_restoring: 2
-good_should_be_corrupt: 5
-good_should_be_destroyed: 10
-repairing_should_be_good: -5
-repairing_should_be_restoring: 2
-repairing_should_be_corrupt: 2
-repairing_should_be_destroyed: 0
-repairing: -3
-restoring_should_be_good: -10
-restoring_should_be_repairing: -2
-restoring_should_be_corrupt: 1
-restoring_should_be_destroyed: 2
-restoring: -6
-corrupt_should_be_good: -10
-corrupt_should_be_repairing: -10
-corrupt_should_be_restoring: -10
-corrupt_should_be_destroyed: 2
-corrupt: -10
-destroyed_should_be_good: -20
-destroyed_should_be_repairing: -20
-destroyed_should_be_restoring: -20
-destroyed_should_be_corrupt: -20
-destroyed: -20
-scanning: -2
+good_should_be_repairing: 0.0002
+good_should_be_restoring: 0.0002
+good_should_be_corrupt: 0.0005
+good_should_be_destroyed: 0.001
+repairing_should_be_good: -0.0005
+repairing_should_be_restoring: 0.0002
+repairing_should_be_corrupt: 0.0002
+repairing_should_be_destroyed: 0.0000
+repairing: -0.0003
+restoring_should_be_good: -0.001
+restoring_should_be_repairing: -0.0002
+restoring_should_be_corrupt: 0.0001
+restoring_should_be_destroyed: 0.0002
+restoring: -0.0006
+corrupt_should_be_good: -0.001
+corrupt_should_be_repairing: -0.001
+corrupt_should_be_restoring: -0.001
+corrupt_should_be_destroyed: 0.0002
+corrupt: -0.001
+destroyed_should_be_good: -0.002
+destroyed_should_be_repairing: -0.002
+destroyed_should_be_restoring: -0.002
+destroyed_should_be_corrupt: -0.002
+destroyed: -0.002
+scanning: -0.0002
# IER status
-red_ier_running: -5
-green_ier_blocked: -10
+red_ier_running: -0.0005
+green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
diff --git a/tests/config/train_episode_step.yaml b/tests/config/train_episode_step.yaml
index f112b741..31337b0c 100644
--- a/tests/config/train_episode_step.yaml
+++ b/tests/config/train_episode_step.yaml
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
diff --git a/tests/config/training_config_main_rllib.yaml b/tests/config/training_config_main_rllib.yaml
new file mode 100644
index 00000000..40cbc0fc
--- /dev/null
+++ b/tests/config/training_config_main_rllib.yaml
@@ -0,0 +1,164 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+# Training Config File
+
+# Sets which agent algorithm framework will be used.
+# Options are:
+# "SB3" (Stable Baselines3)
+# "RLLIB" (Ray RLlib)
+# "CUSTOM" (Custom Agent)
+agent_framework: RLLIB
+
+# Sets which deep learning framework will be used (by RLlib ONLY).
+# Default is TF (Tensorflow).
+# Options are:
+# "TF" (Tensorflow)
+# TF2 (Tensorflow 2.X)
+# TORCH (PyTorch)
+deep_learning_framework: TF2
+
+# Sets which Agent class will be used.
+# Options are:
+# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
+# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
+# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
+# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
+# "RANDOM" (primaite.agents.simple.RandomAgent)
+# "DUMMY" (primaite.agents.simple.DummyAgent)
+agent_identifier: PPO
+
+# Sets whether Red Agent POL and IER is randomised.
+# Options are:
+# True
+# False
+random_red_agent: False
+
+# The (integer) seed to be used in random number generation
+# Default is None (null)
+seed: null
+
+# Set whether the agent will be deterministic instead of stochastic
+# Options are:
+# True
+# False
+deterministic: False
+
+# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
+# Options are:
+# "BASIC" (The current observation space only)
+# "FULL" (Full environment view with actions taken and reward feedback)
+hard_coded_agent_view: FULL
+
+# Sets How the Action Space is defined:
+# "NODE"
+# "ACL"
+# "ANY" node and acl actions
+action_type: NODE
+# observation space
+observation_space:
+ # flatten: true
+ components:
+ - name: NODE_LINK_TABLE
+ # - name: NODE_STATUSES
+ # - name: LINK_TRAFFIC_LEVELS
+
+
+# Number of episodes for training to run per session
+num_train_episodes: 10
+
+# Number of time_steps for training per episode
+num_train_steps: 256
+
+# Number of episodes for evaluation to run per session
+num_eval_episodes: 1
+
+# Number of time_steps for evaluation per episode
+num_eval_steps: 256
+
+# Sets how often the agent will save a checkpoint (every n time episodes).
+# Set to 0 if no checkpoints are required. Default is 10
+checkpoint_every_n_episodes: 10
+
+# Time delay (milliseconds) between steps for CUSTOM agents.
+time_delay: 5
+
+# Type of session to be run. Options are:
+# "TRAIN" (Trains an agent)
+# "EVAL" (Evaluates an agent)
+# "TRAIN_EVAL" (Trains then evaluates an agent)
+session_type: TRAIN_EVAL
+
+# Environment config values
+# The high value for the observation space
+observation_space_high_value: 1000000000
+
+# The Stable Baselines3 learn/eval output verbosity level:
+# Options are:
+# "NONE" (No Output)
+# "INFO" (Info Messages (such as devices and wrappers used))
+# "DEBUG" (All Messages)
+sb3_output_verbose_level: NONE
+
+# Reward values
+# Generic
+all_ok: 0
+# Node Hardware State
+off_should_be_on: -0.001
+off_should_be_resetting: -0.0005
+on_should_be_off: -0.0002
+on_should_be_resetting: -0.0005
+resetting_should_be_on: -0.0005
+resetting_should_be_off: -0.0002
+resetting: -0.0003
+# Node Software or Service State
+good_should_be_patching: 0.0002
+good_should_be_compromised: 0.0005
+good_should_be_overwhelmed: 0.0005
+patching_should_be_good: -0.0005
+patching_should_be_compromised: 0.0002
+patching_should_be_overwhelmed: 0.0002
+patching: -0.0003
+compromised_should_be_good: -0.002
+compromised_should_be_patching: -0.002
+compromised_should_be_overwhelmed: -0.002
+compromised: -0.002
+overwhelmed_should_be_good: -0.002
+overwhelmed_should_be_patching: -0.002
+overwhelmed_should_be_compromised: -0.002
+overwhelmed: -0.002
+# Node File System State
+good_should_be_repairing: 0.0002
+good_should_be_restoring: 0.0002
+good_should_be_corrupt: 0.0005
+good_should_be_destroyed: 0.001
+repairing_should_be_good: -0.0005
+repairing_should_be_restoring: 0.0002
+repairing_should_be_corrupt: 0.0002
+repairing_should_be_destroyed: 0.0000
+repairing: -0.0003
+restoring_should_be_good: -0.001
+restoring_should_be_repairing: -0.0002
+restoring_should_be_corrupt: 0.0001
+restoring_should_be_destroyed: 0.0002
+restoring: -0.0006
+corrupt_should_be_good: -0.001
+corrupt_should_be_repairing: -0.001
+corrupt_should_be_restoring: -0.001
+corrupt_should_be_destroyed: 0.0002
+corrupt: -0.001
+destroyed_should_be_good: -0.002
+destroyed_should_be_repairing: -0.002
+destroyed_should_be_restoring: -0.002
+destroyed_should_be_corrupt: -0.002
+destroyed: -0.002
+scanning: -0.0002
+# IER status
+red_ier_running: -0.0005
+green_ier_blocked: -0.001
+
+# Patching / Reset durations
+os_patching_duration: 5 # The time taken to patch the OS
+node_reset_duration: 5 # The time taken to reset a node (hardware)
+service_patching_duration: 5 # The time taken to patch a service
+file_system_repairing_limit: 5 # The time take to repair the file system
+file_system_restoring_limit: 5 # The time take to restore the file system
+file_system_scanning_limit: 5 # The time taken to scan the file system
diff --git a/tests/conftest.py b/tests/conftest.py
index aaf4dbce..9b0db139 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,11 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import datetime
import json
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, Union
+from typing import Any, Dict, Tuple, Union
from unittest.mock import patch
import pytest
@@ -13,7 +13,7 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
-from primaite.utils.session_output_reader import av_rewards_dict
+from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
- def learn_av_reward_per_episode(self) -> Dict[int, float]:
+ def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
- def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
+ def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
+ def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
+ """Get the learn all transactions from file."""
+ csv_file = f"all_transactions_{self.timestamp_str}.csv"
+ return all_transactions_dict(self.learning_path / csv_file)
+
+ def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
+ """Get the eval all transactions from file."""
+ csv_file = f"all_transactions_{self.timestamp_str}.csv"
+ return all_transactions_dict(self.evaluation_path / csv_file)
+
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
@@ -62,7 +72,6 @@ class TempPrimaiteSession(PrimaiteSession):
def __exit__(self, type, value, tb):
shutil.rmtree(self.session_path)
- shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@@ -114,7 +123,7 @@ def temp_primaite_session(request):
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
- with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
+ with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)
diff --git a/tests/mock_and_patch/__init__.py b/tests/mock_and_patch/__init__.py
index e69de29b..778748f7 100644
--- a/tests/mock_and_patch/__init__.py
+++ b/tests/mock_and_patch/__init__.py
@@ -0,0 +1 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py
index 90c0cb5d..190e1dba 100644
--- a/tests/mock_and_patch/get_session_path_mock.py
+++ b/tests/mock_and_patch/get_session_path_mock.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import tempfile
from datetime import datetime
from pathlib import Path
diff --git a/tests/test_acl.py b/tests/test_acl.py
index 30f12697..4ef9d78c 100644
--- a/tests/test_acl.py
+++ b/tests/test_acl.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to tes the ACL functions."""
from primaite.acl.access_control_list import AccessControlList
diff --git a/tests/test_active_node.py b/tests/test_active_node.py
index addc595c..880c0f02 100644
--- a/tests/test_active_node.py
+++ b/tests/test_active_node.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to test Active Node functions."""
import pytest
diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py
index d5844fd9..3bcdb66d 100644
--- a/tests/test_observation_space.py
+++ b/tests/test_observation_space.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Test env creation and behaviour with different observation spaces."""
import numpy as np
import pytest
diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py
index 75ea5882..210d931e 100644
--- a/tests/test_primaite_session.py
+++ b/tests/test_primaite_session.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
import pytest
diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py
index f8885f3e..3496ed9d 100644
--- a/tests/test_red_random_agent_behaviour.py
+++ b/tests/test_red_random_agent_behaviour.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite.config.lay_down_config import data_manipulation_config_path
diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py
index fb7dc83d..80e13c5b 100644
--- a/tests/test_resetting_node.py
+++ b/tests/test_resetting_node.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to test Active Node functions."""
import pytest
diff --git a/tests/test_reward.py b/tests/test_reward.py
index bb6eb1b0..741c6f13 100644
--- a/tests/test_reward.py
+++ b/tests/test_reward.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite import getLogger
@@ -48,5 +49,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
"""
with temp_primaite_session as session:
session.evaluate()
- ev_rewards = session.eval_av_reward_per_episode_csv()
+ ev_rewards = session.eval_av_reward_per_episode_dict()
assert ev_rewards[1] == -8.0
diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py
new file mode 100644
index 00000000..f494ea81
--- /dev/null
+++ b/tests/test_rllib_agent.py
@@ -0,0 +1,24 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+import pytest
+
+from primaite import getLogger
+from primaite.config.lay_down_config import dos_very_basic_config_path
+from tests import TEST_CONFIG_ROOT
+
+_LOGGER = getLogger(__name__)
+
+
+@pytest.mark.parametrize(
+ "temp_primaite_session",
+ [[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
+ indirect=True,
+)
+def test_primaite_session(temp_primaite_session):
+ """Test the training_config_main_rllib.yaml training config file."""
+ with temp_primaite_session as session:
+ session_path = session.session_path
+ assert session_path.exists()
+ session.learn()
+
+ assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
+ assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py
index 34cb43fb..5220fb1c 100644
--- a/tests/test_seeding_and_deterministic_session.py
+++ b/tests/test_seeding_and_deterministic_session.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest as pytest
from primaite.config.lay_down_config import dos_very_basic_config_path
@@ -28,7 +29,7 @@ def test_seeded_learning(temp_primaite_session):
"Expected output is based upon a agent that was trained with " "seed 67890"
)
session.learn()
- actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
+ actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@@ -45,5 +46,5 @@ def test_deterministic_evaluation(temp_primaite_session):
# do stuff
session.learn()
session.evaluate()
- eval_mean_reward = session.eval_av_reward_per_episode_csv()
+ eval_mean_reward = session.eval_av_reward_per_episode_dict()
assert len(set(eval_mean_reward.values())) == 1
diff --git a/tests/test_service_node.py b/tests/test_service_node.py
index 4383fc1b..2f504cd6 100644
--- a/tests/test_service_node.py
+++ b/tests/test_service_node.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to test Service Node functions."""
import pytest
diff --git a/tests/test_session_loading.py b/tests/test_session_loading.py
new file mode 100644
index 00000000..bcd28d96
--- /dev/null
+++ b/tests/test_session_loading.py
@@ -0,0 +1,189 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+import os.path
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Union
+from uuid import uuid4
+
+from primaite import getLogger
+from primaite.agents.sb3 import SB3Agent
+from primaite.common.enums import AgentFramework, AgentIdentifier
+from primaite.main import run
+from primaite.primaite_session import PrimaiteSession
+from primaite.utils.session_output_reader import av_rewards_dict
+from tests import TEST_ASSETS_ROOT
+
+_LOGGER = getLogger(__name__)
+
+
+def copy_session_asset(asset_path: Union[str, Path]) -> str:
+ """Copies the asset into a temporary test folder."""
+ if asset_path is None:
+ raise Exception("No path provided")
+
+ if isinstance(asset_path, Path):
+ asset_path = str(os.path.normpath(asset_path))
+
+ copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
+
+ # copy the asset into a temp path
+ try:
+ shutil.copytree(asset_path, copy_path)
+ except Exception as e:
+ msg = f"Unable to copy directory: {asset_path}"
+ _LOGGER.error(msg, e)
+ print(msg, e)
+
+ _LOGGER.debug(f"Copied test asset to: {copy_path}")
+
+ # return the copied assets path
+ return copy_path
+
+
+def test_load_sb3_session():
+ """Test that loading an SB3 agent works."""
+ expected_learn_mean_reward_per_episode = {
+ 10: 0,
+ 11: -0.008037109374999995,
+ 12: -0.007978515624999988,
+ 13: -0.008191406249999991,
+ 14: -0.00817578124999999,
+ 15: -0.008085937499999998,
+ 16: -0.007837890624999982,
+ 17: -0.007798828124999992,
+ 18: -0.007777343749999998,
+ 19: -0.007958984374999988,
+ 20: -0.0077499999999999835,
+ }
+
+ test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
+
+ loaded_agent = SB3Agent(session_path=test_path)
+
+ # loaded agent should have the same UUID as the previous agent
+ assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
+ assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
+ assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
+ assert loaded_agent._training_config.deterministic
+ assert loaded_agent._training_config.seed == 12345
+ assert str(loaded_agent.session_path) == str(test_path)
+
+ # run another learn session
+ loaded_agent.learn()
+
+ learn_mean_rewards = av_rewards_dict(
+ loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
+ )
+
+ # run is seeded so should have the expected learn value
+ assert learn_mean_rewards == expected_learn_mean_reward_per_episode
+
+ # run an evaluation
+ loaded_agent.evaluate()
+
+ # load the evaluation average reward csv file
+ eval_mean_reward = av_rewards_dict(
+ loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
+ )
+
+ # the agent config ran the evaluation in deterministic mode, so should have the same reward value
+ assert len(set(eval_mean_reward.values())) == 1
+
+ # the evaluation should be the same as a previous run
+ assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
+
+ # delete the test directory
+ shutil.rmtree(test_path)
+
+
+def test_load_primaite_session():
+ """Test that loading a Primaite session works."""
+ expected_learn_mean_reward_per_episode = {
+ 10: 0,
+ 11: -0.008037109374999995,
+ 12: -0.007978515624999988,
+ 13: -0.008191406249999991,
+ 14: -0.00817578124999999,
+ 15: -0.008085937499999998,
+ 16: -0.007837890624999982,
+ 17: -0.007798828124999992,
+ 18: -0.007777343749999998,
+ 19: -0.007958984374999988,
+ 20: -0.0077499999999999835,
+ }
+
+ test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
+
+ # create loaded session
+ session = PrimaiteSession(session_path=test_path)
+
+ # run setup on session
+ session.setup()
+
+ # make sure that the session was loaded correctly
+ assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
+ assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
+ assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
+ assert session._agent_session._training_config.deterministic
+ assert session._agent_session._training_config.seed == 12345
+ assert str(session._agent_session.session_path) == str(test_path)
+
+ # run another learn session
+ session.learn()
+
+ learn_mean_rewards = av_rewards_dict(
+ session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
+ )
+
+ # run is seeded so should have the expected learn value
+ assert learn_mean_rewards == expected_learn_mean_reward_per_episode
+
+ # run an evaluation
+ session.evaluate()
+
+ # load the evaluation average reward csv file
+ eval_mean_reward = av_rewards_dict(
+ session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
+ )
+
+ # the agent config ran the evaluation in deterministic mode, so should have the same reward value
+ assert len(set(eval_mean_reward.values())) == 1
+
+ # the evaluation should be the same as a previous run
+ assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
+
+ # delete the test directory
+ shutil.rmtree(test_path)
+
+
+def test_run_loading():
+ """Test loading session via main.run."""
+ expected_learn_mean_reward_per_episode = {
+ 10: 0,
+ 11: -0.008037109374999995,
+ 12: -0.007978515624999988,
+ 13: -0.008191406249999991,
+ 14: -0.00817578124999999,
+ 15: -0.008085937499999998,
+ 16: -0.007837890624999982,
+ 17: -0.007798828124999992,
+ 18: -0.007777343749999998,
+ 19: -0.007958984374999988,
+ 20: -0.0077499999999999835,
+ }
+
+ test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
+
+ # create loaded session
+ run(session_path=test_path)
+
+ learn_mean_rewards = av_rewards_dict(
+ next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
+ )
+
+ # run is seeded so should have the expected learn value
+ assert learn_mean_rewards == expected_learn_mean_reward_per_episode
+
+ # delete the test directory
+ shutil.rmtree(test_path)
diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py
index bfcffd42..4f7af9a6 100644
--- a/tests/test_single_action_space.py
+++ b/tests/test_single_action_space.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import time
import pytest
diff --git a/tests/test_train_eval_episode_steps.py b/tests/test_train_eval_episode_steps.py
index b839e630..4f7bed16 100644
--- a/tests/test_train_eval_episode_steps.py
+++ b/tests/test_train_eval_episode_steps.py
@@ -1,3 +1,4 @@
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite import getLogger
diff --git a/tests/test_training_config.py b/tests/test_training_config.py
index d7fe4e50..4123ee39 100644
--- a/tests/test_training_config.py
+++ b/tests/test_training_config.py
@@ -1,4 +1,4 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import yaml
from primaite.config import training_config