Merge branch 'dev' into feature/1623-typehints

This commit is contained in:
Marek Wolan
2023-07-18 10:03:48 +01:00
109 changed files with 1321 additions and 348 deletions

View File

@@ -86,5 +86,5 @@ stages:
displayName: 'Perform PrimAITE Setup'
- script: |
pytest tests/
pytest -n 4
displayName: 'Run tests'

3
.gitignore vendored
View File

@@ -50,6 +50,9 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
tests/assets/**/*.png
tests/assets/**/tensorboard_logs/
tests/assets/**/checkpoints/
# Translations
*.mo

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
..
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
Welcome to PrimAITE's documentation
====================================
@@ -20,17 +24,24 @@ What is PrimAITE built with
* `OpenAI's Gym <https://gym.openai.com/>`_ is used as the basis for AI blue agent interaction with the PrimAITE environment
* `Networkx <https://github.com/networkx/networkx>`_ is used as the underlying data structure used for the PrimAITE environment
* `Stable Baselines 3 <https://github.com/DLR-RM/stable-baselines3>`_ is used as a default source of RL algorithms (although PrimAITE is not limited to SB3 agents)
* `Ray RLlib <https://github.com/ray-project/ray>`_ is used as an additional source of RL algorithms
* `Typer <https://github.com/tiangolo/typer>`_ is used for building CLIs (Command Line Interface applications)
* `Jupyterlab <https://github.com/jupyterlab/jupyterlab>`_ is used as an extensible environment for interactive and reproducible computing, based on the Jupyter Notebook Architecture
* `Platformdirs <https://github.com/platformdirs/platformdirs>`_ is used for finding the right location to store user data and configuration but varies per platform
* `Plotly <https://github.com/plotly/plotly.py>`_ is used for building high level charts
Where next?
------------
The best place to start is :ref:`about`
Head over to the :ref:`getting-started` page to install and setup PrimAITE!
.. toctree::
:maxdepth: 8
:caption: Contents:
:hidden:
source/getting_started
source/about
source/config
source/primaite_session
@@ -41,12 +52,14 @@ The best place to start is :ref:`about`
source/glossary
source/migration_1.2_-_2.0
.. TODO: Add project links once public repo has been created
.. toctree::
:caption: Project Links:
:hidden:
..
#Code <>
#Issues <>
#Pull Requests <>
#Discussions <>
Code <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE>
Issues <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues>
Pull Requests <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/pulls>
Discussions <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/discussions>

View File

@@ -1,4 +1,8 @@
.. _about:
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _about:
About PrimAITE
==============

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _config:
The Config Files Explained

View File

@@ -1,4 +1,8 @@
Custom Agents
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
Custom Agents
=============

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. role:: raw-html(raw)
:format: html

View File

@@ -0,0 +1,155 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _getting-started:
Getting Started
===============
**Getting Started with PrimAITE**
Pre-Requisites
In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.10 installed. If you don't already have it, this is how to install it:
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt install python3.10
sudo apt-get install python3-pip
sudo apt-get install python3-venv
.. code-tab:: text
:caption: Windows (Powershell)
- Manual install from: https://www.python.org/downloads/release/python-31011/
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
Install PrimAITE
****************
1. Create a primaite directory in your home directory:
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
mkdir ~/primaite
.. code-tab:: powershell
:caption: Windows (Powershell)
mkdir ~\primaite
2. Navigate to the primaite directory and create a new python virtual environment (venv)
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
cd ~/primaite
python3 -m venv .venv
.. code-tab:: powershell
:caption: Windows (Powershell)
cd ~\primaite
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
3. Activate the venv
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
source .venv/bin/activate
.. code-tab:: powershell
:caption: Windows (Powershell)
.\.venv\Scripts\activate
4. Install PrimAITE using pip from PyPi
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
pip install primaite
.. code-tab:: powershell
:caption: Windows (Powershell)
pip install primaite
5. Perform the PrimAITE setup
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
primaite setup
.. code-tab:: powershell
:caption: Windows (Powershell)
primaite setup
Clone & Install PrimAITE for Development
****************************************
To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
of your choice:
.. TODO:: Add repo path once we know what it is
.. code-block:: bash
git clone <repo path>
cd primaite
Create and activate your Python virtual environment (venv)
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
python3 -m venv venv
source venv/bin/activate
.. code-tab:: powershell
:caption: Windows (Powershell)
python3 -m venv venv
.\venv\Scripts\activate
Install PrimAITE with the dev extra
.. tabs:: lang
.. code-tab:: bash
:caption: Unix
pip install -e .[dev]
.. code-tab:: powershell
:caption: Windows (Powershell)
pip install -e .[dev]
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
Glossary
=============

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
v1.2 to v2.0 Migration guide
============================

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _run a primaite session:
Run a PrimAITE Session
@@ -10,6 +14,7 @@ A PrimAITE session can be ran either with the ``primaite session`` command from
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
.. tabs::
.. code-tab:: bash
@@ -19,7 +24,7 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf
source ./.venv/bin/activate
primaite session ./config/my_training_config.yaml ./config/my_lay_down_config.yaml
.. code-tab:: bash
.. code-tab:: powershell
:caption: Powershell CLI
cd ~\primaite
@@ -42,6 +47,37 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
Loading a session
-----------------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
Outputs
-------

View File

@@ -58,6 +58,7 @@ dev = [
"pip-licenses==4.3.0",
"pre-commit==2.20.0",
"pytest==7.2.0",
"pytest-xdist==3.3.1",
"pytest-cov==4.0.0",
"pytest-flake8==1.1.1",
"setuptools==66",

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from setuptools import setup
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import logging
import logging.config
import sys

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Access Control List. Models firewall functionality."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Common interface between RL agents from different libraries and PrimAITE."""

View File

@@ -1,28 +1,24 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, TYPE_CHECKING, Union
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from uuid import uuid4
import yaml
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
if TYPE_CHECKING:
from logging import Logger
import numpy as np
_LOGGER: "Logger" = getLogger(__name__)
@@ -53,38 +49,63 @@ class AgentSessionABC(ABC):
"""
@abstractmethod
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise an agent session from config files.
Initialise an agent session from config files, or load a previous session.
If training configuration and laydown configuration are provided with a session path,
the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
# convert session to path
if session_path is not None:
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
# load session
self.load(session_path)
else:
# set training config path
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = str(uuid4())
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
@@ -233,51 +254,27 @@ class AgentSessionABC(ABC):
def _get_latest_checkpoint(self) -> None:
pass
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
def load(self, path: Union[str, Path]):
"""Load an agent from file."""
if not isinstance(path, Path):
path = Path(path)
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
if path.exists():
# Unpack the session_metadata.json file
md_file = path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# set training config path
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
self._lay_down_config_path: Union[Path, str] = laydown_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = path / ".temp"
temp_dir.mkdir(exist_ok=True)
# set random UUID for session
self._uuid = md_dict["uuid"]
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
agent = cls(temp_tc, temp_ldc)
agent.session_path = path
return agent
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
# set the session path
self.session_path = path
"The Session path"
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod
@@ -313,104 +310,3 @@ class AgentSessionABC(ABC):
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
super().evaluate()
@classmethod
def load(cls) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -0,0 +1,116 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import time
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls, path=None):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,10 +1,11 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,

View File

@@ -1,6 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable

View File

@@ -1,10 +1,11 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -14,7 +15,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -48,7 +49,12 @@ def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the RLLib Agent training session.
@@ -61,6 +67,13 @@ class RLlibAgent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.error(msg)
print(msg)
raise NotImplementedError
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"

View File

@@ -1,14 +1,16 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -21,7 +23,12 @@ _LOGGER: "Logger" = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the SB3 Agent training session.
@@ -34,7 +41,7 @@ class SB3Agent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(training_config_path, lay_down_config_path)
super().__init__(training_config_path, lay_down_config_path, session_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -51,7 +58,7 @@ class SB3Agent(AgentSessionABC):
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
@@ -61,8 +68,10 @@ class SB3Agent(AgentSessionABC):
self.is_eval = False
self._setup()
def _setup(self) -> None:
super()._setup()
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -70,14 +79,43 @@ class SB3Agent(AgentSessionABC):
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
# check if there is a zip file that needs to be loaded
load_file = next(self.session_path.rglob("*.zip"), None)
if not load_file:
# create a new env and agent
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
# load environment values
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
self._env.total_step_count = 0
else:
# carry on from previous learning sessions
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
# load the file
self._agent = self._agent_class.load(load_file, env=self._env)
# set agent values
self._agent.verbose = self.sb3_output_verbose_level
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
super()._setup()
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -149,11 +187,6 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
@classmethod
def load(cls, path: Union[str, Path]) -> SB3Agent:
"""Load an agent from file."""
raise NotImplementedError
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)

View File

@@ -1,6 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
if TYPE_CHECKING:

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Provides a CLI using Typer as an entry point."""
import logging
import os
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True) -> None:
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
"""
Run a PrimAITE session.
@@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
and laydown config.
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
if load is not None:
run(session_path=load)
if not tc:
tc = main_training_config_path()

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Objects which are shared between many PrimAITE modules."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""
from enum import Enum, IntEnum

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The protocol class."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Service class."""
from primaite.common.enums import SoftwareState

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Configuration parameters for running experiments."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Any, Dict, Final, TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utility to generate plots of sessions metrics after PrimAITE."""
from enum import Enum

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Dict, Optional, Union

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import logging

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict, TYPE_CHECKING, Union

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Network connections between nodes in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The link class."""
from typing import List

View File

@@ -1,8 +1,8 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Union
from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
session = PrimaiteSession(training_config_path, lay_down_config_path)
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
session.setup()
session.learn()
@@ -31,9 +36,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
args = parser.parse_args()
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Nodes represent network hosts in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The base Node class."""
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import importlib.util
import os
import subprocess

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Pattern of Life- Represents the actions of users on the network."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""
Information Exchange Requirements for APE.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict

View File

@@ -1,11 +1,12 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Final, Union
from typing import Any, Dict, Final, Optional, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
@@ -14,6 +15,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
@@ -27,15 +29,39 @@ class PrimaiteSession:
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = session_path # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
# check if session path is provided
if session_path is not None:
# set load_session to true
self.is_load_session = True
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
@@ -60,11 +86,15 @@ class PrimaiteSession:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = HardCodedNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = HardCodedACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -77,11 +107,15 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DoNothingNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DoNothingACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -93,10 +127,14 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = RandomAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DummyAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
else:
# Invalid AgentFramework AgentIdentifier combo
@@ -105,12 +143,12 @@ class PrimaiteSession:
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
else:
# Invalid AgentFramework

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities to prepare the user's data folders."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import getLogger

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Record data of the system's state and agent's observations and actions."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities for PrimAITE."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
from pathlib import Path
from typing import TYPE_CHECKING

View File

@@ -0,0 +1,59 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import json
from pathlib import Path
from typing import Union
import yaml
from primaite import getLogger
_LOGGER = getLogger(__name__)
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
"""
Loads a session metadata from the given directory path.
:param session_path: Directory where the session metadata file is in
:param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
:return: Dictionary which has all the session metadata contents
:rtype: Dict
:return: Path where the YAML copy of the training config is dumped into
:rtype: str
:return: Path where the YAML copy of the laydown config is dumped into
:rtype: str
"""
if not isinstance(session_path, Path):
session_path = Path(session_path)
if not session_path.exists():
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
# Unpack the session_metadata.json file
md_file = session_path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# if dict only, return dict without doing anything else
if dict_only:
return md_dict
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = session_path / ".temp"
temp_dir.mkdir(exist_ok=True)
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
return [md_dict, temp_tc, temp_ldc]

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
@@ -13,8 +14,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.
:return: The average rewards per episode csv as a dict.
"""
df = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""
Read an all transactions csv file and return as a dict.
The dict keys are a tuple with the structure (episode, step). The dict
values are the remaining columns as a dict.
:param all_transactions_csv_file: The all transactions csv file path.
:return: The all transactions csv file as a dict.
"""
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
new_dict = {}
episodes = df_dict["Episode"]
steps = df_dict["Step"]
keys = list(df_dict.keys())
for i in range(len(episodes)):
key = (episodes[i], steps[i])
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
new_dict[key] = value_dict
return new_dict

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union

View File

@@ -1,6 +1,9 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
"The tests config root directory."
TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
"The tests assets root directory."

File diff suppressed because one or more lines are too long

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Main Config File
# Generic config values

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Main Config File
# Generic config values

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
@@ -7,6 +8,14 @@
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
@@ -17,32 +26,78 @@ agent_framework: CUSTOM
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
@@ -5,7 +6,15 @@
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
@@ -15,7 +24,7 @@ agent_framework: CUSTOM
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
@@ -23,92 +32,128 @@ agent_identifier: DUMMY
# False
random_red_agent: True
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 2
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 15
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 2
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
num_eval_steps: 256
# Type of session to be run (TRAINING or EVALUATION)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -5
green_ier_blocked: -10
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -0,0 +1,164 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: RLLIB
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,11 +1,11 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import datetime
import json
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Tuple, Union
from unittest.mock import patch
import pytest
@@ -13,7 +13,7 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
def learn_av_reward_per_episode(self) -> Dict[int, float]:
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
@@ -62,7 +72,6 @@ class TempPrimaiteSession(PrimaiteSession):
def __exit__(self, type, value, tb):
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@@ -114,7 +123,7 @@ def temp_primaite_session(request):
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)

View File

@@ -0,0 +1 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import tempfile
from datetime import datetime
from pathlib import Path

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to tes the ACL functions."""
from primaite.acl.access_control_list import AccessControlList

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to test Active Node functions."""
import pytest

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Test env creation and behaviour with different observation spaces."""
import numpy as np
import pytest

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
import pytest

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite.config.lay_down_config import data_manipulation_config_path

Some files were not shown because too many files have changed in this diff Show More