Merged PR 101: Integrate ADSP RLlib and use PrimaiteSession for running between agent frameworks

## Summary
* Brought over the RLlib agents, hardcoded agents, and simple agents from ADSP 1.1.0. This opened a can of worms... ADSP got their stuff working in notebooks (**_stares at data scientists!_** 😂) but hadn't integrated it into the PrimAITE package or made the other PrimAITE functionality work with it.
* RLlib agents have been fully integrated with the wider PrimAITE package. This was done by:
  * The creation of the `AgentSessionABC` and `HardCodedAgentSessionABC` classes.
  * `SB3Agent` and `RLlibAgent` classes then inherited from `AgentSessionABC`.
  * The ADSP hardcoded agents were integrated into subclasses of `HardCodedAgentSessionABC`.
  * The random and dummy agents were also integrated into subclasses of `HardCodedAgentSessionABC`.
  * A set of session output directories were created and managed by the agent session to enable consistent storage of session outputs in a common format regardless of the agent type.
  * The main config was refactored so that it has:
    * **agent_framework** - To identify whether SB3, RLlib, or Custom.
    * **agent_identifier** - To identify whether PPO, A2C, hardcoded, random, or dummy.
    * **deep_learning_framework** - To identify which framework to use for RLlib.
* Transactions have been overhauled to simplify the process. It also means that they're written in real time so they're not lost if the agent crashes.
* Tests completely overhauled to use `PrimaiteSession`, or at least a test subclass, `TempPrimaiteSession`. It's temp because it uses a temp directory rather than the main PrimAITE session directory, and it cleans up after itself.
* All the crap was removed from `main.py` so that it just runs `PrimaiteSession`.
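
The class layout described above can be sketched roughly as follows. This is an illustration only, not the code in this PR: the class names `AgentSessionABC`, `HardCodedAgentSessionABC`, `SB3Agent`, and `RLlibAgent` and the config keys come from the summary, while the method names (`learn`, `evaluate`, `calculate_action`), the `AgentFramework` enum, the `DummyAgent` class, and the `build_session` helper are assumptions made up for the example.

```python
from abc import ABC, abstractmethod
from enum import Enum


class AgentFramework(Enum):
    """Hypothetical enum mirroring the agent_framework config key."""

    CUSTOM = "CUSTOM"
    SB3 = "SB3"
    RLLIB = "RLLIB"


class AgentSessionABC(ABC):
    """Common interface for every agent session (method names are assumed)."""

    def __init__(self, agent_identifier: str) -> None:
        self.agent_identifier = agent_identifier

    @abstractmethod
    def learn(self) -> str:
        ...

    @abstractmethod
    def evaluate(self) -> str:
        ...


class HardCodedAgentSessionABC(AgentSessionABC):
    """Base for agents that are not trained; subclasses only pick actions."""

    def learn(self) -> str:
        return "learning skipped: agent is not trainable"

    def evaluate(self) -> str:
        return f"{self.agent_identifier} chose action {self.calculate_action()}"

    @abstractmethod
    def calculate_action(self) -> int:
        ...


class DummyAgent(HardCodedAgentSessionABC):
    """Hypothetical dummy agent that always returns the same action."""

    def calculate_action(self) -> int:
        return 0


class SB3Agent(AgentSessionABC):
    def learn(self) -> str:
        return f"SB3 {self.agent_identifier} training"

    def evaluate(self) -> str:
        return f"SB3 {self.agent_identifier} evaluation"


class RLlibAgent(AgentSessionABC):
    def learn(self) -> str:
        return f"RLlib {self.agent_identifier} training"

    def evaluate(self) -> str:
        return f"RLlib {self.agent_identifier} evaluation"


def build_session(agent_framework: AgentFramework, agent_identifier: str) -> AgentSessionABC:
    """Pick a session class from the two config values (hypothetical helper)."""
    if agent_framework is AgentFramework.SB3:
        return SB3Agent(agent_identifier)
    if agent_framework is AgentFramework.RLLIB:
        return RLlibAgent(agent_identifier)
    return DummyAgent(agent_identifier)
```

Because every session shares the same base interface, the caller can run `learn()` and `evaluate()` without knowing which framework sits behind it.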

Now this is where I went off on a tangent...
* CLI added to just make my life and everyone else's life easier.
* Primaite app config added to hold things like logging format, levels etc.
* A `primaite.data_viz.session_plots` module added so that the average reward per episode is plotted and saved for each session (this helped while we were testing and bug fixing).
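
For reference, the aggregation behind an "average reward per episode" plot can be sketched as below. This is not the `session_plots` implementation: the function name and the `(episode, reward)` input shape are assumptions, and the plotting call itself is left out so the example needs only the standard library.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Iterable, Tuple


def average_reward_per_episode(step_rewards: Iterable[Tuple[int, float]]) -> Dict[int, float]:
    """Group (episode, reward) pairs by episode and average each group."""
    rewards_by_episode = defaultdict(list)
    for episode, reward in step_rewards:
        rewards_by_episode[episode].append(reward)
    # Sort by episode number so the result can be plotted directly as a line.
    return {episode: mean(rewards) for episode, rewards in sorted(rewards_by_episode.items())}
```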

## Test process
* All tests use `TempPrimaiteSession`, which uses `PrimaiteSession`.
* I still need to write tests that run the RLlib, hardcoded, and random/dummy agents. I'll do that now while this is being reviewed.
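
The temp-directory idea behind `TempPrimaiteSession` can be sketched like this. It is a simplified stand-in, not the class from this PR: the `PrimaiteSession` stub, its `session_path` attribute, and the context-manager interface are assumptions for illustration.

```python
import tempfile
from pathlib import Path


class PrimaiteSession:
    """Stand-in for the real class; only the session directory is modelled."""

    def __init__(self, session_root: Path) -> None:
        self.session_path = session_root / "session"
        self.session_path.mkdir(parents=True, exist_ok=True)


class TempPrimaiteSession(PrimaiteSession):
    """Session that writes into a temp directory and removes it on exit."""

    def __init__(self) -> None:
        self._tmp = tempfile.TemporaryDirectory()
        super().__init__(Path(self._tmp.name))

    def __enter__(self) -> "TempPrimaiteSession":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self._tmp.cleanup()  # delete all session outputs
        return False  # never swallow a failing assertion


# Usage in a test: outputs exist inside the block and are gone afterwards.
with TempPrimaiteSession() as session:
    (session.session_path / "out.txt").write_text("ok")
```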

## Still to do
* Update docs. I'm getting this PR up now so we can get it in to make use of the features. I'll get the docs updated today either on this branch or another branch (depending on how long this review takes).

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #917, #1563
Christopher McCarthy
2023-07-04 08:08:31 +00:00
70 changed files with 3581 additions and 1369 deletions


@@ -1,38 +1,12 @@
trigger:
- main
pool:
vmImage: ubuntu-latest
strategy:
matrix:
Ubuntu2004Python38:
python.version: '3.8'
imageName: 'ubuntu-20.04'
Ubuntu2004Python39:
python.version: '3.9'
imageName: 'ubuntu-20.04'
Ubuntu2004Python310:
Python310:
python.version: '3.10'
imageName: 'ubuntu-20.04'
WindowsPython38:
python.version: '3.8'
imageName: 'windows-latest'
WindowsPython39:
python.version: '3.9'
imageName: 'windows-latest'
WindowsPython310:
python.version: '3.10'
imageName: 'windows-latest'
MacPython38:
python.version: '3.8'
imageName: 'macOS-latest'
MacPython39:
python.version: '3.9'
imageName: 'macOS-latest'
MacPython310:
python.version: '3.10'
imageName: 'macOS-latest'
pool:
vmImage: $(imageName)
steps:
- task: UsePythonVersion@0


@@ -6,18 +6,29 @@ trigger:
- bugfix/*
- release/*
pool:
vmImage: ubuntu-latest
strategy:
matrix:
Python38:
UbuntuPython38:
python.version: '3.8'
Python39:
python.version: '3.9'
Python310:
imageName: 'ubuntu-latest'
UbuntuPython310:
python.version: '3.10'
Python311:
python.version: '3.11'
imageName: 'ubuntu-latest'
WindowsPython38:
python.version: '3.8'
imageName: 'windows-latest'
WindowsPython310:
python.version: '3.10'
imageName: 'windows-latest'
MacOSPython38:
python.version: '3.8'
imageName: 'macOS-latest'
MacOSPython310:
python.version: '3.10'
imageName: 'macOS-latest'
pool:
vmImage: $(imageName)
steps:
- task: UsePythonVersion@0
@@ -47,16 +58,17 @@ steps:
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
python -m pip install $PRIMAITE_WHEEL[dev]
displayName: 'Install PrimAITE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
displayName: 'Install PrimAITE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
primaite setup
displayName: 'Perform PrimAITE Setup'
#- script: |
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# displayName: 'Lint with flake8'
- script: |
pytest tests/
displayName: 'Run unmarked tests'
displayName: 'Run tests'


@@ -13,6 +13,9 @@ repos:
rev: 23.1.0
hooks:
- id: black
args: [ "--line-length=120" ]
additional_dependencies:
- jupyter
- repo: http://github.com/pycqa/isort
rev: 5.12.0
hooks:
@@ -22,4 +25,5 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies: [ flake8-docstrings ]
additional_dependencies:
- flake8-docstrings


@@ -20,13 +20,24 @@ The environment config file consists of the following attributes:
**Generic Config Values**
* **agent_identifier** [enum]
This identifies the agent to use for the session. Select from one of the following:
* **agent_framework** [enum]
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
* NONE - Where a user developed agent is to be used
* SB3 - Stable Baselines3
* RLLIB - Ray RLlib.
* **agent_identifier**
This identifies the agent to use for the session. Select from one of the following:
* A2C - Advantage Actor Critic
* PPO - Proximal Policy Optimization
* HARDCODED - A custom built deterministic agent
* RANDOM - A Stochastic random agent
* GENERIC - Where a user developed agent is to be used
* STABLE_BASELINES3_PPO - Use a SB3 PPO agent
* STABLE_BASELINES3_A2C - use a SB3 A2C agent
* **random_red_agent** [bool]
@@ -34,38 +45,38 @@ The environment config file consists of the following attributes:
* **action_type** [enum]
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
* **num_episodes** [int]
This defines the number of episodes that the agent will train or be evaluated over.
This defines the number of episodes that the agent will train or be evaluated over.
* **num_steps** [int]
Determines the number of steps to run in each episode of the session
Determines the number of steps to run in each episode of the session
* **time_delay** [int]
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
* **session_type** [text]
Type of session to be run (TRAINING or EVALUATION)
Type of session to be run (TRAINING, EVALUATION, or BOTH)
* **load_agent** [bool]
Determine whether to load an agent from file
Determine whether to load an agent from file
* **agent_load_file** [text]
File path and file name of agent if you're loading one in
File path and file name of agent if you're loading one in
* **observation_space_high_value** [int]
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
**Reward-Based Config Values**
@@ -73,95 +84,95 @@ Rewards are calculated based on the difference between the current state and ref
* **Generic [all_ok]** [int]
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
* **Node Hardware State [off_should_be_on]** [int]
The score to give when the node should be on, but is off
The score to give when the node should be on, but is off
* **Node Hardware State [off_should_be_resetting]** [int]
The score to give when the node should be resetting, but is off
The score to give when the node should be resetting, but is off
* **Node Hardware State [on_should_be_off]** [int]
The score to give when the node should be off, but is on
The score to give when the node should be off, but is on
* **Node Hardware State [on_should_be_resetting]** [int]
The score to give when the node should be resetting, but is on
The score to give when the node should be resetting, but is on
* **Node Hardware State [resetting_should_be_on]** [int]
The score to give when the node should be on, but is resetting
The score to give when the node should be on, but is resetting
* **Node Hardware State [resetting_should_be_off]** [int]
The score to give when the node should be off, but is resetting
The score to give when the node should be off, but is resetting
* **Node Hardware State [resetting]** [int]
The score to give when the node is resetting
The score to give when the node is resetting
* **Node Operating System or Service State [good_should_be_patching]** [int]
The score to give when the state should be patching, but is good
The score to give when the state should be patching, but is good
* **Node Operating System or Service State [good_should_be_compromised]** [int]
The score to give when the state should be compromised, but is good
The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]
The score to give when the state should be overwhelmed, but is good
The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patching_should_be_good]** [int]
The score to give when the state should be good, but is patching
The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patching_should_be_compromised]** [int]
The score to give when the state should be compromised, but is patching
The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]
The score to give when the state should be overwhelmed, but is patching
The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [int]
The score to give when the state is patching
The score to give when the state is patching
* **Node Operating System or Service State [compromised_should_be_good]** [int]
The score to give when the state should be good, but is compromised
The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromised_should_be_patching]** [int]
The score to give when the state should be patching, but is compromised
The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]
The score to give when the state should be overwhelmed, but is compromised
The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [int]
The score to give when the state is compromised
The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]
The score to give when the state should be good, but is overwhelmed
The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]
The score to give when the state should be patching, but is overwhelmed
The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]
The score to give when the state should be compromised, but is overwhelmed
The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [int]
The score to give when the state is overwhelmed
The score to give when the state is overwhelmed
* **Node File System State [good_should_be_repairing]** [int]
@@ -265,37 +276,37 @@ Rewards are calculated based on the difference between the current state and ref
* **IER Status [red_ier_running]** [int]
The score to give when a red agent IER is permitted to run
The score to give when a red agent IER is permitted to run
* **IER Status [green_ier_blocked]** [int]
The score to give when a green agent IER is prevented from running
The score to give when a green agent IER is prevented from running
**Patching / Reset Durations**
* **os_patching_duration** [int]
The number of steps to take when patching an Operating System
The number of steps to take when patching an Operating System
* **node_reset_duration** [int]
The number of steps to take when resetting a node's hardware state
The number of steps to take when resetting a node's hardware state
* **service_patching_duration** [int]
The number of steps to take when patching a service
The number of steps to take when patching a service
* **file_system_repairing_limit** [int]:
The number of steps to take when repairing the file system
The number of steps to take when repairing the file system
* **file_system_restoring_limit** [int]
The number of steps to take when restoring the file system
The number of steps to take when restoring the file system
* **file_system_scanning_limit** [int]
The number of steps to take when scanning the file system
The number of steps to take when scanning the file system
The Lay Down Config
*******************
@@ -304,22 +315,22 @@ The lay down config file consists of the following attributes:
* **itemType: ACTIONS** [enum]
Determines whether a NODE or ACL action space format is adopted for the session
Determines whether a NODE or ACL action space format is adopted for the session
* **itemType: OBSERVATION_SPACE** [dict]
Allows for user to configure observation space by combining one or more observation components. List of available
components is is :py:mod:'primaite.environment.observations'.
Allows for user to configure observation space by combining one or more observation components. List of available
components is is :py:mod:'primaite.environment.observations'.
The observation space config item should have a ``components`` key which is a list of components. Each component
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
component while it is being initialised.
The observation space config item should have a ``components`` key which is a list of components. Each component
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
component while it is being initialised.
This example illustrates the correct format for the observation space config item
This example illustrates the correct format for the observation space config item
.. code-block::yaml
- itemType: OBSERVATION_SPACE
- item_type: OBSERVATION_SPACE
components:
- name: LINK_TRAFFIC_LEVELS
options:
@@ -332,15 +343,15 @@ The lay down config file consists of the following attributes:
* **item_type: PORTS** [int]
Provides a list of ports modelled in this session
Provides a list of ports modelled in this session
* **item_type: SERVICES** [freetext]
Provides a list of services modelled in this session
Provides a list of services modelled in this session
* **item_type: NODE**
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
@@ -359,7 +370,7 @@ The lay down config file consists of the following attributes:
* **item_type: LINK**
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
@@ -369,7 +380,7 @@ The lay down config file consists of the following attributes:
* **item_type: GREEN_IER**
Defines a green agent Information Exchange Requirement (IER). It should consist of:
Defines a green agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin
@@ -383,7 +394,7 @@ The lay down config file consists of the following attributes:
* **item_type: RED_IER**
Defines a red agent Information Exchange Requirement (IER). It should consist of:
Defines a red agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin


@@ -47,6 +47,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| astunparse | 1.6.3 | BSD License | https://github.com/simonpercivall/astunparse |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall |
@@ -103,6 +105,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn |
@@ -111,9 +115,13 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| furo | 2023.3.27 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-auth-oauthlib | 1.0.0 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| grpcio | 1.51.3 | Apache Software License | https://grpc.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -121,6 +129,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| h5py | 3.9.0 | BSD License | https://www.h5py.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| idna | 3.4 | BSD License | https://github.com/kjd/idna |
@@ -141,6 +151,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 |
@@ -151,14 +163,14 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-events | 0.6.3 | BSD License | http://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_client | 8.2.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_core | 5.3.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org |
@@ -171,10 +183,14 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| keras | 2.12.0 | Apache Software License | https://keras.io/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| kiwisolver | 1.4.4 | BSD License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py |
@@ -183,19 +199,23 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mpmath | 1.3.0 | BSD License | http://mpmath.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic |
| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| nbclient | 0.8.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -217,6 +237,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging |
@@ -231,11 +253,17 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| plotly | 5.15.0 | MIT License | https://plotly.com/python/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 1.2.1 | MIT License | UNKNOWN |
| primaite | 2.0.0rc1 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 2.0.0rc1 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 2.0.0rc1 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -295,6 +323,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| rsa | 4.9 | Apache Software License | https://stuvel.eu/rsa |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| scikit-image | 0.20.0 | BSD License | https://scikit-image.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| scipy | 1.10.1 | BSD License | https://scipy.org/ |
@@ -335,14 +365,26 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard | 2.12.3 | Apache Software License | https://github.com/tensorflow/tensorboard |
| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-data-server | 0.7.0 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com |
@@ -377,6 +419,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket |


@@ -7,7 +7,7 @@ name = "primaite"
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
authors = [{name="QinetiQ Training and Simulation Ltd"}]
license = {text = "MIT License"}
requires-python = ">=3.8"
requires-python = ">=3.8, <3.11"
dynamic = ["version", "readme"]
classifiers = [
"License :: MIT License",
@@ -20,19 +20,23 @@ classifiers = [
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"gym==0.21.0",
"jupyterlab==3.6.1",
"kaleido==0.2.1",
"matplotlib==3.7.1",
"networkx==3.1",
"numpy==1.23.5",
"platformdirs==3.5.1",
"plotly==5.15.0",
"polars==0.18.4",
"PyYAML==6.0",
"ray[rllib]==2.2.0",
"stable-baselines3==1.6.2",
"tensorflow==2.12.0",
"typer[all]==0.9.0"
]
@@ -62,9 +66,15 @@ dev = [
"sphinx-copybutton==0.5.2",
"wheel==0.38.4"
]
tensorflow = [
"tensorflow==2.12.0",
]
[project.scripts]
primaite = "primaite.cli:app"
[tool.isort]
profile = "black"
line_length = 120
force_sort_within_sections = "False"
order_by_type = "False"
[tool.black]
line-length = 120


@@ -1 +1 @@
2.0.0dev0
2.0.0rc1


@@ -2,10 +2,11 @@
import logging
import logging.config
import sys
from logging import Logger, StreamHandler
from bisect import bisect
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Final
from typing import Dict, Final
import pkg_resources
import yaml
@@ -18,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
def _get_primaite_config():
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
if not config_path.exists():
config_path = Path(
pkg_resources.resource_filename(
"primaite", "setup/_package_data/primaite_config.yaml"
)
)
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
with open(config_path, "r") as file:
primaite_config = yaml.safe_load(file)
log_level_map = {
@@ -33,7 +30,7 @@ def _get_primaite_config():
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
return primaite_config
@@ -68,6 +65,28 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
# region Setup Logging
class _LevelFormatter(Formatter):
"""
A custom level-specific formatter.
Credit to: https://stackoverflow.com/a/68154386
"""
def __init__(self, formats: Dict[int, str], **kwargs):
super().__init__()
if "fmt" in kwargs:
raise ValueError("Format string must be passed to level-surrogate formatters, " "not this one")
self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items())
def format(self, record: LogRecord) -> str:
"""Overrides ``Formatter.format``."""
idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
level, formatter = self.formats[idx]
return formatter.format(record)
def _log_dir() -> Path:
if sys.platform == "win32":
dir_path = _PLATFORM_DIRS.user_data_path / "logs"
@@ -76,6 +95,16 @@ def _log_dir() -> Path:
return dir_path
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
}
)
LOG_DIR: Final[Path] = _log_dir()
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""
@@ -85,18 +114,19 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
filename=LOG_PATH,
maxBytes=10485760, # 10MB
backupCount=9, # Max 100MB of logs
encoding="utf8",
)
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"]
_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
_LOGGER = logging.getLogger(__name__)
@@ -104,7 +134,7 @@ _LOGGER.addHandler(_STREAM_HANDLER)
_LOGGER.addHandler(_FILE_HANDLER)
def getLogger(name: str) -> Logger:
def getLogger(name: str) -> Logger: # noqa
"""
Get a PrimAITE logger.
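
The level-specific formatter added in this file can be tried out in isolation. The following is a minimal sketch, assuming made-up format strings (the real ones come from the `logging.logger_format` section of the PrimAITE config) and a hypothetical `render` helper that builds a log record by hand. It shows how `bisect` picks the formatter for the record's level, and that a level with no entry of its own falls through to the next-highest configured level.

```python
import logging
from bisect import bisect
from typing import Dict


class LevelFormatter(logging.Formatter):
    """Pick a format string based on the level of each record."""

    def __init__(self, formats: Dict[int, str], **kwargs):
        super().__init__()
        # Sorted list of (level, Formatter) pairs, lowest level first.
        self.formats = sorted(
            (level, logging.Formatter(fmt, **kwargs)) for level, fmt in formats.items()
        )

    def format(self, record: logging.LogRecord) -> str:
        # A 1-tuple sorts just before any 2-tuple with the same first item,
        # so an exact level match lands on its own formatter. The hi bound
        # clamps levels above the highest configured one to the last entry.
        idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
        _, formatter = self.formats[idx]
        return formatter.format(record)


# Hypothetical format strings, for illustration only.
formatter = LevelFormatter(
    {
        logging.DEBUG: "DEBUG:: %(message)s",
        logging.INFO: "%(message)s",
        logging.ERROR: "ERROR:: %(name)s: %(message)s",
    }
)


def render(level: int, message: str) -> str:
    """Format a hand-built record so no handlers are needed."""
    record = logging.LogRecord("demo", level, "demo.py", 0, message, None, None)
    return formatter.format(record)


print(render(logging.INFO, "session started"))
print(render(logging.WARNING, "no WARNING entry, so the ERROR format is used"))
```

Because the lookup rounds up rather than down, leaving a level out of the mapping makes it borrow the format of the next level above it, which is worth keeping in mind when editing the config.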


@@ -25,18 +25,9 @@ class AccessControlList:
True if match; False otherwise.
"""
if (
(
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == "ANY"
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == "ANY"
)
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
):
return True
@@ -57,15 +48,9 @@ class AccessControlList:
Indicates block if all conditions are satisfied.
"""
for rule_key, rule_value in self.acl.items():
if self.check_address_match(
rule_value, _source_ip_address, _dest_ip_address
):
if (
rule_value.get_protocol() == _protocol
or rule_value.get_protocol() == "ANY"
) and (
str(rule_value.get_port()) == str(_port)
or rule_value.get_port() == "ANY"
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule_value.get_permission() == "DENY":


@@ -30,7 +30,13 @@ class ACLRule:
Returns hash of core parameters.
"""
return hash(
(self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
(
self.permission,
self.source_ip,
self.dest_ip,
self.protocol,
self.port,
)
)
def get_permission(self):


@@ -0,0 +1,383 @@
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Union
from uuid import uuid4
import yaml
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
class AgentSessionABC(ABC):
"""
An ABC that manages training and/or evaluation of agents in PrimAITE.
This class cannot be directly instantiated and must be inherited from
with all abstract methods implemented.
"""
@abstractmethod
def __init__(self, training_config_path, lay_down_config_path):
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
"""The session timestamp as a string."""
return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
@property
def learning_path(self) -> Path:
"""The learning outputs path."""
path = self.session_path / "learning"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def evaluation_path(self) -> Path:
"""The evaluation outputs path."""
path = self.session_path / "evaluation"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def checkpoints_path(self) -> Path:
"""The Session checkpoints path."""
path = self.learning_path / "checkpoints"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def uuid(self):
"""The Agent Session UUID."""
return self._uuid
def _write_session_metadata_file(self):
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_path`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": self.uuid,
"start_datetime": self.session_timestamp.isoformat(),
"end_datetime": None,
"learning": {"total_episodes": None, "total_time_steps": None},
"evaluation": {"total_episodes": None, "total_time_steps": None},
"env": {
"training_config": self._training_config.to_dict(json_serializable=True),
"lay_down_config": self._lay_down_config,
},
}
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished writing session metadata file")
def _update_session_metadata_file(self):
"""
Update the ``session_metadata.json`` file.
Updates the ``session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
@abstractmethod
def _setup(self):
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
self._write_session_metadata_file()
self._can_learn = True
self._can_evaluate = False
@abstractmethod
def _save_checkpoint(self):
pass
@abstractmethod
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_learn:
_LOGGER.info("Finished learning")
_LOGGER.debug("Writing transactions")
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
self._plot_av_reward_per_episode(learning_session=True)
@abstractmethod
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self):
pass
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
"""Load an agent from file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists():
# Unpack the session_metadata.json file
md_file = path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = path / ".temp"
temp_dir.mkdir(exist_ok=True)
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
agent = cls(temp_tc, temp_ldc)
agent.session_path = path
return agent
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@abstractmethod
def save(self):
"""Save the agent."""
self._agent.save(self.session_path)
@abstractmethod
def export(self):
"""Export the agent to transportable file format."""
pass
def close(self):
"""Closes the agent."""
self._env.episode_av_reward_writer.close() # noqa
self._env.transaction_writer.close() # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True):
# self.close()
title = f"PrimAITE Session {self.timestamp_str} "
subtitle = str(self._training_config)
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
if learning_session:
title += "(Learning)"
path = self.learning_path / csv_file
image_path = self.learning_path / image_file
else:
title += "(Evaluation)"
path = self.evaluation_path / csv_file
image_path = self.evaluation_path / image_file
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {image_path}")
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluating deterministic agents.
This class cannot be directly instantiated and must be inherited from
with all abstract methods implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")
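
The session output layout and metadata handling in `AgentSessionABC` can be sketched outside of PrimAITE. This is a minimal illustration only: it assumes a temporary base directory in place of `SESSIONS_DIR`, a hypothetical helper name `session_path_for`, a trimmed-down metadata dict, and a made-up episode count. It follows the same pattern as the class: create the dated session directory, write `session_metadata.json` with placeholders, then read it back and fill the values in.

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path


def session_path_for(base_dir: Path, started: datetime) -> Path:
    """Build <base>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss> and create it."""
    date_dir = started.strftime("%Y-%m-%d")
    session_dir = started.strftime("%Y-%m-%d_%H-%M-%S")
    path = base_dir / date_dir / session_dir
    path.mkdir(exist_ok=True, parents=True)
    return path


# Hypothetical base directory; the PR uses SESSIONS_DIR instead.
base = Path(tempfile.mkdtemp())
started = datetime(2023, 7, 4, 9, 30, 15)
session = session_path_for(base, started)

# Sub-directories mirroring the learning/evaluation/checkpoints properties.
checkpoints = session / "learning" / "checkpoints"
evaluation = session / "evaluation"
for sub_dir in (checkpoints, evaluation):
    sub_dir.mkdir(exist_ok=True, parents=True)

# Write the metadata with placeholders first...
metadata_file = session / "session_metadata.json"
metadata = {
    "start_datetime": started.isoformat(),
    "end_datetime": None,
    "learning": {"total_episodes": None, "total_time_steps": None},
}
metadata_file.write_text(json.dumps(metadata))

# ...then read it back and fill the placeholders in once the run is done.
metadata = json.loads(metadata_file.read_text())
metadata["end_datetime"] = datetime.now().isoformat()
metadata["learning"]["total_episodes"] = 10  # made-up value
metadata_file.write_text(json.dumps(metadata))

print(session.relative_to(base).as_posix())  # 2023-07-04/2023-07-04_09-30-15
```

Writing the file up front means a session that crashes still leaves its UUID-style identifying data and configs on disk, with the unfilled `None` fields showing that it never finished.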


@@ -0,0 +1,431 @@
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.enums import HardCodedAgentView
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
# full view action using observation space, action
# history and reward feedback
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes):
"""
Get blocked green IERs.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
blocked_green_iers = {}
for green_ier_id, green_ier in green_iers.items():
source_node_id = green_ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = green_ier.get_dest_node_id()
dest_node_address = nodes[dest_node_id].ip_address
protocol = green_ier.get_protocol() # e.g. 'TCP'
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get matching ACL rules for an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = ier.get_dest_node_id()
dest_node_address = nodes[dest_node_id].ip_address
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get blocking ACL rules for an IER.
.. warning::
Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked).
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
blocked_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "DENY":
blocked_rules[rule_key] = rule_value
return blocked_rules
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get all allowing ACL rules for an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "ALLOW":
allowed_rules[rule_key] = rule_value
return allowed_rules
def get_matching_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get matching ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if source_node_id != "ANY":
source_node_address = nodes[str(source_node_id)].ip_address
else:
source_node_address = source_node_id
if dest_node_id != "ANY":
dest_node_address = nodes[str(dest_node_id)].ip_address
else:
dest_node_address = dest_node_id
if protocol != "ANY":
protocol = services_list[protocol - 1] # -1 as we don't have to account for ANY in the list of services
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_allow_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the ALLOW ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "ALLOW":
allowed_rules[rule_key] = rule_value
return allowed_rules
def get_deny_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the DENY ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
denied_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "DENY":
denied_rules[rule_key] = rule_value
return denied_rules
def _calculate_action_full_view(self, obs):
"""
Calculate a good acl-based action for the blue agent to take.
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
- Which ACL rules already exist. Otherwise:
- The agent would permanently get stuck in a loop of performing the same action over and over
(the best action is to block something, but it is already blocked and the agent doesn't know this).
- The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
if it doesn't know what rules exist).
- The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
in the default config one of the green IERs is blocked by default, but it has no way of knowing this
based on the observation space. Additionally, potentially in the future, once a node state
has been fixed (no longer compromised), it needs a way to know it should reallow traffic.
A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this.
There doesn't seem to be much that can be done if an Operating or OS State is compromised.
If a service node becomes compromised there's a decision to make - do we block that service?
Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
Cons: Will block a green IER, decreasing the reward
We decide to block the service.
Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
nodes once a service becomes overwhelmed. However, currently the ACL action space has no way of reversing
an overwhelmed state, so we don't do this.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# obs = convert_to_old_obs(obs)
r_obs = transform_change_obs_readable(obs)
_, _, _, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
# 1. Check if a node is compromised. If so, we want to block its outward services.
# a. If it is compromised, check if there's an allow rule we should delete.
# cons: might delete a multi-rule from any source node (ANY -> x)
# b. OPTIONAL (DENY rules not needed): Check whether a DENY rule already exists, so as not to duplicate it
# c. OPTIONAL (no allow rule = blocked): Add a DENY rule
found_action = False
for service_num, service_states in enumerate(s):
for x, service_state in enumerate(service_states):
if service_state == "COMPROMISED":
action_source_id = x + 1 # +1 as 0 is any
action_destination_id = "ANY"
action_protocol = service_num + 1 # +1 as 0 is any
action_port = "ANY"
allow_rules = self.get_allow_acl_rules(
action_source_id,
action_destination_id,
action_protocol,
action_port,
self._env.acl,
self._env.nodes,
self._env.services_list,
)
deny_rules = self.get_deny_acl_rules(
action_source_id,
action_destination_id,
action_protocol,
action_port,
self._env.acl,
self._env.nodes,
self._env.services_list,
)
if len(allow_rules) > 0:
# Check if there's an allow rule we should delete
rule = list(allow_rules.values())[0]
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
elif len(deny_rules) > 0:
# TODO OPTIONAL
# If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
# to create another
# Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
continue
else:
# TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
action_decision = "CREATE"
action_permission = "DENY"
break
if found_action:
break
# 2. If NO node is compromised, or the node has already been blocked, check the green IERs and
# add an ALLOW rule if the green IER is being blocked.
# a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
# If there's a DENY rule delete it if:
# - There isn't already a deny rule
# - It doesn't allow a compromised node to become operational.
# b. Add an ALLOW rule if:
# - There isn't already an allow rule
# - It doesn't allow a compromised node to become operational
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(service_name_to_check)
# Service state of the source node in the IER
service_state = s[service_id_to_check][node_id_to_check - 1]
if len(allowing_rules) == 0 and service_state != "COMPROMISED":
action_decision = "CREATE"
action_permission = "ALLOW"
action_source_id = int(ier.get_source_node_id())
action_destination_id = int(ier.get_dest_node_id())
action_protocol_name = ier.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = ier.get_port()
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
if found_action:
action = [
action_decision,
action_permission,
action_source_id,
action_destination_id,
action_protocol,
action_port,
]
action = transform_action_acl_enum(action)
action = get_new_action(action, self._env.action_dict)
else:
# If no good/useful action has been found, just perform a nothing action
action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
action = transform_action_acl_enum(action)
action = get_new_action(action, self._env.action_dict)
return action
def _calculate_action_basic_view(self, obs):
"""Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge
of previous actions taken and NO reward feedback.
We rely on randomness to select the precise action, as we want to
block all traffic originating from a compromised node, without being
able to tell:
1. Which ACL rules already exist
2. Which actions the agent has already tried.
There is a high probability that the correct rule will not be deleted
before the state becomes overwhelmed.
Currently, a deny rule does not overwrite an allow rule. The allow
rules must be deleted.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, _, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
for service_num, service_states in enumerate(s):
compromised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
if len(compromised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = np.random.choice(compromised_states) + 1 # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = (
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block
# Bad assumption that number of protocols equals number of ports
# AND no rules exist with an ANY port
action_port = np.random.choice(list(range(1, len(s) + 1)))
action = [
action_decision,
action_permission,
action_source_ip,
action_destination_ip,
action_protocol,
action_port,
]
action = transform_action_acl_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# If no good/useful action has been found, just perform a nothing action
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, action_dict)
return nothing_action
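
The `get_allow_*` and `get_deny_*` helpers in this file all follow one pattern: fetch the matching rules, then keep those with a given permission. A minimal sketch of that pattern as a single function is shown here. The `Rule` dataclass and `filter_rules_by_permission` are hypothetical names used for illustration only and are not part of PrimAITE.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Rule:
    """Hypothetical stand-in for an ACL rule (illustration only)."""

    permission: str  # "ALLOW" or "DENY"
    source_ip: str
    dest_ip: str
    protocol: str
    port: str

    def get_permission(self) -> str:
        return self.permission


def filter_rules_by_permission(rules: Dict[int, Rule], permission: str) -> Dict[int, Rule]:
    """Keep only the rules whose permission equals ``permission``."""
    return {key: rule for key, rule in rules.items() if rule.get_permission() == permission}


# Example: three matching rules, two ALLOW and one DENY.
matching_rules = {
    0: Rule("ALLOW", "192.168.1.2", "192.168.1.3", "TCP", "80"),
    1: Rule("DENY", "192.168.1.2", "ANY", "TCP", "ANY"),
    2: Rule("ALLOW", "ANY", "192.168.1.3", "UDP", "53"),
}
allow_rules = filter_rules_by_permission(matching_rules, "ALLOW")
deny_rules = filter_rules_by_permission(matching_rules, "DENY")
print(sorted(allow_rules), sorted(deny_rules))
```

An empty DENY result does not mean traffic is allowed: as the docstrings above note, an IER with no ALLOW rule is blocked by default.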

@@ -0,0 +1,119 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs):
"""
Calculate a good node-based action for the blue agent to take.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, os, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
# Check in order of most important states (order doesn't currently
# matter, but it probably should)
# First see if any OS states are compromised
for x, os_state in enumerate(os):
if os_state == "COMPROMISED":
action_node_id = x + 1
action_node_property = "OS"
property_action = "PATCHING"
action_service_index = 0 # unused: not relevant for OS
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Next, see if any Services are compromised
# We fix the compromised state before the overwhelmed state.
# If a compromised entry node is fixed before the overwhelmed state is triggered, the instruction is ignored
for service_num, service in enumerate(s):
for x, service_state in enumerate(service):
if service_state == "COMPROMISED":
action_node_id = x + 1
action_node_property = "SERVICE"
property_action = "PATCHING"
action_service_index = service_num
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Next, See if any services are overwhelmed
# perhaps this should be fixed automatically when the compromised PCs issues are also resolved
# Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs
for service_num, service in enumerate(s):
for x, service_state in enumerate(service):
if service_state == "OVERWHELMED":
action_node_id = x + 1
action_node_property = "SERVICE"
property_action = "PATCHING"
action_service_index = service_num
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Finally, turn on any off nodes
for x, operating_state in enumerate(o):
if operating_state == "OFF":
action_node_id = x + 1
action_node_property = "OPERATING"
property_action = "ON" # Why reset it when we can just turn it on
action_service_index = 0 # unused: not relevant for operating state
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# If no good actions, just go with an action that won't do any harm
action_node_id = 1
action_node_property = "NONE"
property_action = "ON"
action_service_index = 0
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
return action
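
The node agent above scans the observation in a fixed priority order: compromised OS, compromised services, overwhelmed services, nodes that are off, and finally a harmless no-op. A compact sketch of that ordering on plain lists is shown here. `first_index` and `choose_node_action` are hypothetical helper names, and the sketch returns the readable action tuple without the enum and `action_dict` lookups the real agent performs.

```python
from typing import List, Optional, Tuple

# (node_id, node_property, property_action, service_index)
NodeAction = Tuple[int, str, str, int]


def first_index(states: List[str], target: str) -> Optional[int]:
    """Return the index of the first state equal to ``target``, or None."""
    for index, state in enumerate(states):
        if state == target:
            return index
    return None


def choose_node_action(
    operating: List[str], os_states: List[str], services: List[List[str]]
) -> NodeAction:
    """Pick one readable node action using the priority order described above."""
    # 1. Patch the first compromised OS.
    idx = first_index(os_states, "COMPROMISED")
    if idx is not None:
        return (idx + 1, "OS", "PATCHING", 0)
    # 2. Patch compromised services, then 3. overwhelmed services.
    for target in ("COMPROMISED", "OVERWHELMED"):
        for service_num, service_states in enumerate(services):
            idx = first_index(service_states, target)
            if idx is not None:
                return (idx + 1, "SERVICE", "PATCHING", service_num)
    # 4. Turn on the first node whose operating state is OFF.
    idx = first_index(operating, "OFF")
    if idx is not None:
        return (idx + 1, "OPERATING", "ON", 0)
    # 5. Nothing to fix: return an action that has no effect.
    return (1, "NONE", "ON", 0)


print(choose_node_action(["ON", "OFF"], ["GOOD", "GOOD"], [["GOOD", "GOOD"]]))
```

Node IDs are one-based, hence the `idx + 1` in every branch.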

@@ -0,0 +1,171 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Union
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _env_creator(env_config):
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"],
)
def _custom_log_creator(session_path: Path):
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config):
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: Union[PPOConfig, A2CConfig]
self._current_result: dict
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
def _update_session_metadata_file(self):
"""
Update the ``session_metadata.json`` file.
Updates the ``session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self):
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
self._agent_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
),
)
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self._agent.stop()
super().learn()
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError
def _get_latest_checkpoint(self):
raise NotImplementedError
@classmethod
def load(cls, path: Union[str, Path]) -> RLlibAgent:
"""Load an agent from file."""
raise NotImplementedError
def save(self):
"""Save the agent."""
raise NotImplementedError
def export(self):
"""Export the agent to transportable file format."""
raise NotImplementedError
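
The checkpoint rule in `_save_checkpoint` can be read as a small pure function: save every `checkpoint_every_n_episodes` episodes and always on the final episode, but never when checkpointing is disabled (a value of 0) or before the first episode completes. The sketch below isolates that rule; `should_save_checkpoint` is a hypothetical name, not a function in this PR.

```python
def should_save_checkpoint(episode_count: int, checkpoint_every_n: int, num_episodes: int) -> bool:
    """Return True when a checkpoint should be written after ``episode_count`` episodes."""
    if checkpoint_every_n <= 0 or episode_count <= 0:
        # Checkpointing disabled, or no episode has completed yet.
        return False
    return episode_count % checkpoint_every_n == 0 or episode_count == num_episodes


# With 10 episodes and a checkpoint every 4 episodes, saves happen at 4, 8, and 10.
saved = [e for e in range(1, 11) if should_save_checkpoint(e, 4, 10)]
print(saved)  # [4, 8, 10]
```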

src/primaite/agents/sb3.py Normal file
@@ -0,0 +1,141 @@
from __future__ import annotations
from pathlib import Path
from typing import Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
_LOGGER.error(msg)
raise ValueError(msg)
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}"
)
self.is_eval = False
def _setup(self):
super()._setup()
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=str(self._tensorboard_log_path),
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self._env.close()
super().learn()
def evaluate(
self,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
@classmethod
def load(cls, path: Union[str, Path]) -> SB3Agent:
"""Load an agent from file."""
raise NotImplementedError
def save(self):
"""Save the agent."""
raise NotImplementedError
def export(self):
"""Export the agent to transportable file format."""
raise NotImplementedError
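
The evaluation loop in `SB3Agent.evaluate` follows the usual predict/step cycle. The sketch below reproduces that cycle with toy stand-ins so it runs without Stable Baselines3 or PrimAITE installed. `ToyEnv`, `ToyPolicy`, and `run_evaluation` are hypothetical, and unlike the loop above this sketch stops an episode early when the environment reports `done`.

```python
import random
from typing import List, Tuple


class ToyEnv:
    """Hypothetical gym-style environment using the 4-tuple step API."""

    def __init__(self, episode_length: int) -> None:
        self.episode_length = episode_length
        self._step = 0

    def reset(self) -> int:
        self._step = 0
        return 0

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        self._step += 1
        reward = 1.0 if action == 0 else 0.0  # action 0 is the "good" action
        done = self._step >= self.episode_length
        return self._step, reward, done, {}


class ToyPolicy:
    """Hypothetical policy exposing a predict(obs, deterministic) method."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def predict(self, obs: int, deterministic: bool = True) -> Tuple[int, None]:
        action = 0 if deterministic else self._rng.choice([0, 1])
        return action, None


def run_evaluation(policy: ToyPolicy, env: ToyEnv, episodes: int, time_steps: int,
                   deterministic: bool = True) -> List[float]:
    """Run the policy for a number of episodes and return the total reward of each."""
    totals = []
    for _ in range(episodes):
        obs = env.reset()
        total = 0.0
        for _ in range(time_steps):
            action, _state = policy.predict(obs, deterministic=deterministic)
            obs, reward, done, _info = env.step(action)
            total += reward
            if done:
                break
        totals.append(total)
    return totals


episode_rewards = run_evaluation(ToyPolicy(), ToyEnv(episode_length=5), episodes=3, time_steps=5)
print(episode_rewards)  # [5.0, 5.0, 5.0]
```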

@@ -0,0 +1,56 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
class RandomAgent(HardCodedAgentSessionABC):
"""
A Random Agent.
Get a completely random action from the action space.
"""
def _calculate_action(self, obs):
return self._env.action_space.sample()
class DummyAgent(HardCodedAgentSessionABC):
"""
A Dummy Agent.
All action spaces are set up so that the dummy action is always 0,
regardless of the action type used.
"""
def _calculate_action(self, obs):
return 0
class DoNothingACLAgent(HardCodedAgentSessionABC):
"""
A do nothing ACL agent.
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
return nothing_action
class DoNothingNodeAgent(HardCodedAgentSessionABC):
"""
A do nothing Node agent.
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
# nothing_action should currently always be 0
return nothing_action
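
The do-nothing agents build a readable action and then convert it to its enumerated form. A rough sketch of that name-to-value conversion for node actions is shown here. The enum classes are local stand-ins that mirror names from `primaite.common.enums`; only `SERVICE = 3` and `PATCHING = 1` are confirmed by the docstring example `[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]`, and every other value is an assumption. `node_action_to_enum` is a hypothetical name.

```python
from enum import Enum
from typing import List, Union


class NodePOLType(Enum):
    """Stand-in enum; values other than SERVICE = 3 are assumed."""

    NONE = 0
    OPERATING = 1
    OS = 2
    SERVICE = 3


class NodeHardwareAction(Enum):
    """Stand-in enum; all values are assumed for illustration."""

    NONE = 0
    ON = 1
    OFF = 2
    RESET = 3


class NodeSoftwareAction(Enum):
    """Stand-in enum; PATCHING = 1 matches the docstring example."""

    NONE = 0
    PATCHING = 1


def node_action_to_enum(action: List[Union[int, str]]) -> List[int]:
    """Convert [node_id, property_name, action_name, service_index] to integers."""
    node_id, node_property, property_action, service_index = action
    if node_property == "OPERATING":
        action_value = NodeHardwareAction[property_action].value
    elif node_property in ("OS", "SERVICE"):
        action_value = NodeSoftwareAction[property_action].value
    else:
        # A "NONE" property ignores the action name entirely.
        action_value = 0
    return [node_id, NodePOLType[node_property].value, action_value, service_index]


do_nothing = node_action_to_enum([1, "NONE", "ON", 0])
patch_service = node_action_to_enum([1, "SERVICE", "PATCHING", 0])
print(do_nothing, patch_service)  # [1, 0, 0, 0] [1, 3, 1, 0]
```

Because the `NONE` property discards the action name, the `"ON"` in the do-nothing action is only a placeholder.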

@@ -1,4 +1,13 @@
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
import numpy as np
from primaite.common.enums import (
HardwareState,
LinkStatus,
NodeHardwareAction,
NodePOLType,
NodeSoftwareAction,
SoftwareState,
)
def transform_action_node_readable(action):
@@ -7,14 +16,15 @@ def transform_action_node_readable(action):
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
2
] <= 1:
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -29,6 +39,9 @@ def transform_action_acl_readable(action):
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
@@ -53,6 +66,9 @@ def is_valid_node_action(action):
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_node_readable(action)
@@ -68,7 +84,10 @@ def is_valid_node_action(action):
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
@@ -83,6 +102,9 @@ def is_valid_acl_action(action):
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_acl_readable(action)
@@ -93,11 +115,7 @@ def is_valid_acl_action(action):
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
@@ -109,7 +127,12 @@ def is_valid_acl_action(action):
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action."""
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if is_valid_acl_action(action) is False:
return False
@@ -125,3 +148,406 @@ def is_valid_acl_action_extra(action):
return False
return True
def transform_change_obs_readable(obs):
"""
Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
os_states = [SoftwareState(i).name for i in obs[:, 2]]
new_obs = [ids, operating_states, os_states]
for service in range(3, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
new_obs.append(service_states)
return new_obs
def transform_obs_readable(obs):
"""
Transform observation to readable format.
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
# Convert list of tuples to list of lists
new_obs = [list(i) for i in new_obs]
return new_obs
def convert_to_new_obs(obs, num_nodes=10):
"""
Convert original gym Box observation space to new multiDiscrete observation space.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
"""
Convert to old observation.
Links filled with 0's as no information is included in new observation space.
example:
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
new_obs = array([[ 1, 1, 1, 1],
[ 2, 1, 1, 1],
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
# Add empty links back and add node ID back
s = np.zeros(
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
dtype=np.int64,
)
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
s[:num_nodes, 1:] = reshaped_nodes # put values back in
new_obs = s
# Add links back in
links = obs[-num_links:]
# Links will be added to the last protocol/service slot but they are not specific to that service
new_obs[num_nodes:, -1] = links
return new_obs
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""
Return string describing change between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 1 changed to GOOD.'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
list_of_changes = []
for n, row in enumerate(obs1 - obs2):
if row.any() != 0:
relevant_changes = np.where(row != 0, obs2[n], -1)
relevant_changes[0] = obs2[n, 0] # ID is always relevant
is_link = relevant_changes[0] > num_nodes
desc = _describe_obs_change_helper(relevant_changes, is_link)
list_of_changes.append(desc)
change_string = "\n ".join(list_of_changes)
if len(list_of_changes) > 0:
change_string = "\n " + change_string
return change_string
def _describe_obs_change_helper(obs_change, is_link):
"""
Helper function to describe what has changed.
example:
[ 1 -1 -1 -1 1] -> "ID 1: SERVICE 1 changed to GOOD."
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 changed to GOOD.'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# Indexes where a change has occurred, not including 0th index
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
# Node pol types, Indexes >= 3 are service nodes
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
# Account for hardware states, software states and links
states = [
LinkStatus(obs_change[i]).name
if is_link
else HardwareState(obs_change[i]).name
if i == 1
else SoftwareState(obs_change[i]).name
for i in index_changed
]
if not is_link:
desc = f"ID {obs_change[0]}:"
for node_pol_type, state in list(zip(NodePOLTypes, states)):
desc = desc + " " + node_pol_type + " changed to " + state + "."
else:
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
return desc
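The masking step that `describe_obs_change` hands to the helper can be sketched in plain Python. This is an illustrative stand-in for the `np.where(row != 0, obs2[n], -1)` call, not the code in the PR; `mask_unchanged` is a hypothetical name, and the rows are taken from the docstring example.

```python
from typing import List


def mask_unchanged(row_before: List[int], row_after: List[int]) -> List[int]:
    """Keep changed values from row_after and mark unchanged columns with -1.

    Hypothetical pure-Python equivalent of the np.where masking step.
    """
    masked = [new if new != old else -1 for old, new in zip(row_before, row_after)]
    # The ID column is always relevant, so restore it even though it never changes.
    masked[0] = row_after[0]
    return masked


before = [1, 1, 1, 1, 3]
after = [1, 1, 1, 1, 1]
print(mask_unchanged(before, after))  # [1, -1, -1, -1, 1]
```

The helper then only needs to look at the positions that are not `-1` to build the description.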
def transform_action_node_enum(action):
"""
Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_id = action[0]
action_node_property = NodePOLType[action[1]].value
if action[1] == "OPERATING":
property_action = NodeHardwareAction[action[2]].value
elif action[1] == "OS" or action[1] == "SERVICE":
property_action = NodeSoftwareAction[action[2]].value
else:
property_action = 0
action_service_index = action[3]
new_action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
return new_action
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
return new_action
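A round trip between the two node action formats might look like the sketch below. The enum classes here are stand-ins defined only for the example: `SERVICE = 3` and `PATCHING = 1` follow the docstring examples, while the remaining member values are assumptions and may differ from `primaite.common.enums`. `to_enum` and `to_readable` are hypothetical names that mirror the logic of the two functions above.

```python
from enum import Enum
from typing import List, Union

Action = List[Union[int, str]]


class NodePOLType(Enum):  # stand-in; only SERVICE = 3 is confirmed by the docstrings
    NONE = 0
    OPERATING = 1
    OS = 2
    SERVICE = 3


class NodeHardwareAction(Enum):  # stand-in; all values assumed
    NONE = 0
    ON = 1
    OFF = 2
    RESET = 3


class NodeSoftwareAction(Enum):  # stand-in; PATCHING = 1 is confirmed by the docstrings
    NONE = 0
    PATCHING = 1


def to_enum(action: Action) -> Action:
    """Readable -> enumerated, following transform_action_node_enum."""
    node_id, prop, prop_action, service_index = action
    if prop == "OPERATING":
        action_value = NodeHardwareAction[prop_action].value
    elif prop in ("OS", "SERVICE"):
        action_value = NodeSoftwareAction[prop_action].value
    else:
        action_value = 0
    return [node_id, NodePOLType[prop].value, action_value, service_index]


def to_readable(action: Action) -> Action:
    """Enumerated -> readable, following transform_action_node_readable."""
    node_id, prop_value, action_value, service_index = action
    prop = NodePOLType(prop_value).name
    if prop == "OPERATING":
        action_name = NodeHardwareAction(action_value).name
    elif prop in ("OS", "SERVICE") and action_value <= 1:
        action_name = NodeSoftwareAction(action_value).name
    else:
        action_name = "NONE"
    return [node_id, prop, action_name, service_index]


print(to_enum([1, "SERVICE", "PATCHING", 0]))  # [1, 3, 1, 0]
print(to_readable([1, 3, 1, 0]))  # [1, 'SERVICE', 'PATCHING', 0]
```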
def node_action_description(action):
"""
Generate string describing a node-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if isinstance(action[1], (int, np.int64)):
# transform action to readable format
action = transform_action_node_readable(action)
node_id = action[0]
node_property = action[1]
property_action = action[2]
service_id = action[3]
if property_action == "NONE":
return ""
if node_property == "OPERATING" or node_property == "OS":
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
elif node_property == "SERVICE":
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
else:
return ""
return description
def transform_action_acl_enum(action):
"""
Convert acl action from readable str format, to enumerated format.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
action_permissions = {"DENY": 0, "ALLOW": 1}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, ANY has value 0, otherwise it's just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == "ANY":
new_action[n + 2] = 0
new_action = np.array(new_action)
return new_action
def acl_action_description(action):
"""
Generate string describing an acl-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if isinstance(action[0], (int, np.int64)):
# transform action to readable format
action = transform_action_acl_readable(action)
if action[0] == "NONE":
description = "NO ACL RULE APPLIED"
else:
description = (
f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
f" for protocol/service index {action[4]} on port index {action[5]}"
)
return description
def get_node_of_ip(ip, node_dict):
"""
Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be obtained from env.nodes)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
for node_key, node_value in node_dict.items():
node_ip = node_value.ip_address
if node_ip == ip:
return node_key
def is_valid_node_action(action):
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action):
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by default. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action):
"""
Harsher version of valid acl actions, does not allow ANY for the protocol or port.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if not is_valid_acl_action(action):
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True
def get_new_action(old_action, action_dict):
"""
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
Old_action can be either node or acl action type
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
for key, val in action_dict.items():
if list(val) == list(old_action):
return key
# Not all possible actions are included in the dict, only valid actions are.
# If the action is not in the dict, it's an invalid action, so return 0.
return 0
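The reverse lookup in `get_new_action` scans every entry on each call. If that ever becomes a hot path, one option is to build a reverse index once and reuse it. The sketch below assumes a small made-up `action_dict`; `build_reverse_index` and `lookup_action_key` are hypothetical helpers, not part of the PR.

```python
from typing import Dict, Sequence, Tuple


def build_reverse_index(action_dict: Dict[int, Sequence[int]]) -> Dict[Tuple[int, ...], int]:
    """Map each old-style action (as a tuple) back to its integer key."""
    return {tuple(val): key for key, val in action_dict.items()}


def lookup_action_key(old_action: Sequence[int], index: Dict[Tuple[int, ...], int]) -> int:
    """Return the key for old_action, or 0 when the action is not in the index."""
    return index.get(tuple(old_action), 0)


# Made-up action dictionary purely for illustration.
action_dict = {0: [1, 0, 0, 0], 1: [1, 1, 1, 0], 2: [1, 3, 1, 0]}
index = build_reverse_index(action_dict)

print(lookup_action_key([1, 3, 1, 0], index))  # 2
print(lookup_action_key([9, 9, 9, 9], index))  # 0
```

The fallback to `0` matches the behaviour of the function above for actions that are missing from the dict.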


@@ -13,6 +13,8 @@ import yaml
from platformdirs import PlatformDirs
from typing_extensions import Annotated
from primaite.data_viz import PlotlyTemplate
app = typer.Typer()
@@ -76,11 +78,12 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
primaite_config = yaml.safe_load(file)
if level:
primaite_config["log_level"] = level.value
primaite_config["logging"]["log_level"] = level.value
with open(user_config_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE Log Level: {level}")
else:
level = primaite_config["log_level"]
level = primaite_config["logging"]["log_level"]
print(f"PrimAITE Log Level: {level}")
@@ -119,30 +122,18 @@ def setup(overwrite_existing: bool = True):
app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
build_config = overwrite_existing or (not user_config_path.exists())
if build_config:
pkg_config_path = Path(
pkg_resources.resource_filename(
"primaite", "setup/_package_data/primaite_config.yaml"
)
)
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, user_config_path)
shutil.copy2(pkg_config_path, user_config_path)
from primaite import getLogger
from primaite.setup import (
old_installation_clean_up,
reset_demo_notebooks,
reset_example_configs,
setup_app_dirs,
)
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs
_LOGGER = getLogger(__name__)
_LOGGER.info("Performing the PrimAITE first-time setup...")
if build_config:
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building the PrimAITE app directories...")
setup_app_dirs.run()
@@ -160,13 +151,54 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: str, ldc: str):
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
"""
Run a PrimAITE session.
:param tc: The training config filepath.
:param ldc: The lay down config file path.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
run(training_config_path=tc, lay_down_config_path=ldc)
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
"""
View or set the plotly template for Session plots.
To View, simply call: primaite plotly-template
To set, call: primaite plotly-template <desired template>
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
"""
app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
if user_config_path.exists():
with open(user_config_path, "r") as file:
primaite_config = yaml.safe_load(file)
if template:
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
with open(user_config_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE plotly template: {template.value}")
else:
template = primaite_config["session"]["outputs"]["plots"]["template"]
print(f"PrimAITE plotly template: {template}")
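The CLI commands in this file read and write nested keys in `primaite_config.yaml`. The sketch below mirrors that layout with a plain dict so the shape is easier to see. The key names come from this diff (`logging.log_level`, `session.outputs.plots.template`, `size`, `range_slider`); every value is invented for illustration, and `set_plotly_template` is a hypothetical helper rather than anything in the package.

```python
from copy import deepcopy

# Assumed shape of primaite_config.yaml as a dict; values are placeholders.
EXAMPLE_CONFIG = {
    "logging": {"log_level": "INFO"},
    "session": {
        "outputs": {
            "plots": {
                "template": "plotly_white",
                "range_slider": False,
                "size": {"auto_size": False, "width": 1500, "height": 900},
            }
        }
    },
}


def set_plotly_template(config: dict, template: str) -> dict:
    """Return a copy of config with the session plot template replaced."""
    updated = deepcopy(config)
    updated["session"]["outputs"]["plots"]["template"] = template
    return updated


new_config = set_plotly_template(EXAMPLE_CONFIG, "plotly_dark")
print(new_config["session"]["outputs"]["plots"]["template"])  # plotly_dark
print(EXAMPLE_CONFIG["session"]["outputs"]["plots"]["template"])  # plotly_white
```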


@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""
from enum import Enum
from enum import Enum, IntEnum
class NodeType(Enum):
@@ -32,6 +32,7 @@ class Priority(Enum):
class HardwareState(Enum):
"""Node hardware state enumeration."""
NONE = 0
ON = 1
OFF = 2
RESETTING = 3
@@ -42,6 +43,7 @@ class HardwareState(Enum):
class SoftwareState(Enum):
"""Software or Service state enumeration."""
NONE = 0
GOOD = 1
PATCHING = 2
COMPROMISED = 3
@@ -79,6 +81,65 @@ class Protocol(Enum):
NONE = 7
class SessionType(Enum):
"""The type of PrimAITE Session to be run."""
TRAIN = 1
"Train an agent"
EVAL = 2
"Evaluate an agent"
TRAIN_EVAL = 3
"Train then evaluate an agent"
class AgentFramework(Enum):
"""The agent algorithm framework/package."""
CUSTOM = 0
"Custom Agent"
SB3 = 1
"Stable Baselines3"
RLLIB = 2
"Ray RLlib"
class DeepLearningFramework(Enum):
"""The deep learning framework."""
TF = "tf"
"Tensorflow"
TF2 = "tf2"
"Tensorflow 2.x"
TORCH = "torch"
"PyTorch"
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"The Hardcoded agents"
DO_NOTHING = 4
"The DoNothing agents"
RANDOM = 5
"The RandomAgent"
DUMMY = 6
"The DummyAgent"
class HardCodedAgentView(Enum):
"""The view the deterministic hard-coded agent has of the environment."""
BASIC = 1
"The current observation space only"
FULL = 2
"Full environment view with actions taken and reward feedback"
class ActionType(Enum):
"""Action type enumeration."""
@@ -128,3 +189,11 @@ class LinkStatus(Enum):
MEDIUM = 2
HIGH = 3
OVERLOAD = 4
class SB3OutputVerboseLevel(IntEnum):
"""The Stable Baselines3 learn/eval output verbosity level."""
NONE = 0
INFO = 1
DEBUG = 2
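The new enums are looked up by name when a config is loaded and written back by name when it is serialised. A short sketch of the difference between name lookup, value lookup, and `IntEnum` comparison follows; the two classes are copied from this diff so the example is self-contained. That `SB3OutputVerboseLevel` is an `IntEnum` so it can be passed where an integer verbosity is expected is an assumption on my part.

```python
from enum import Enum, IntEnum


class DeepLearningFramework(Enum):
    TF = "tf"
    TF2 = "tf2"
    TORCH = "torch"


class SB3OutputVerboseLevel(IntEnum):
    NONE = 0
    INFO = 1
    DEBUG = 2


# Lookup by member name, as the config loader does with the YAML string.
framework = DeepLearningFramework["TF2"]
print(framework.value)  # tf2

# Lookup by value goes the other way.
print(DeepLearningFramework("torch").name)  # TORCH

# IntEnum members behave like ints, so they compare and convert directly.
print(SB3OutputVerboseLevel.DEBUG >= 1)  # True
print(int(SB3OutputVerboseLevel["INFO"]))  # 1
```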


@@ -1,95 +0,0 @@
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """The config class."""
# from dataclasses import dataclass
#
# from primaite.common.enums import ActionType
#
#
# @dataclass()
# class TrainingConfig:
# """Class to hold main config values."""
#
# # Generic
# agent_identifier: str # The Red Agent algo/class to be used
# action_type: ActionType # type of action to use (NODE/ACL/ANY)
# num_episodes: int # number of episodes to train over
# num_steps: int # number of steps in an episode
# time_delay: int # delay between steps (ms) - applies to generic agents only
# # file
# session_type: str # the session type to run (TRAINING or EVALUATION)
# load_agent: str # Determine whether to load an agent from file
# agent_load_file: str # File path and file name of agent if you're loading one in
#
# # Environment
# observation_space_high_value: int # The high value for the observation space
#
# # Reward values
# # Generic
# all_ok: int
# # Node Hardware State
# off_should_be_on: int
# off_should_be_resetting: int
# on_should_be_off: int
# on_should_be_resetting: int
# resetting_should_be_on: int
# resetting_should_be_off: int
# resetting: int
# # Node Software or Service State
# good_should_be_patching: int
# good_should_be_compromised: int
# good_should_be_overwhelmed: int
# patching_should_be_good: int
# patching_should_be_compromised: int
# patching_should_be_overwhelmed: int
# patching: int
# compromised_should_be_good: int
# compromised_should_be_patching: int
# compromised_should_be_overwhelmed: int
# compromised: int
# overwhelmed_should_be_good: int
# overwhelmed_should_be_patching: int
# overwhelmed_should_be_compromised: int
# overwhelmed: int
# # Node File System State
# good_should_be_repairing: int
# good_should_be_restoring: int
# good_should_be_corrupt: int
# good_should_be_destroyed: int
# repairing_should_be_good: int
# repairing_should_be_restoring: int
# repairing_should_be_corrupt: int
# repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
#
# repairing: int
# restoring_should_be_good: int
# restoring_should_be_repairing: int
# restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
#
# restoring_should_be_destroyed: int
# restoring: int
# corrupt_should_be_good: int
# corrupt_should_be_repairing: int
# corrupt_should_be_restoring: int
# corrupt_should_be_destroyed: int
# corrupt: int
# destroyed_should_be_good: int
# destroyed_should_be_repairing: int
# destroyed_should_be_restoring: int
# destroyed_should_be_corrupt: int
# destroyed: int
# scanning: int
# # IER status
# red_ier_running: int
# green_ier_blocked: int
#
# # Patching / Reset
# os_patching_duration: int # The time taken to patch the OS
# node_reset_duration: int # The time taken to reset a node (hardware)
# node_booting_duration = 0 # The Time taken to turn on the node
# node_shutdown_duration = 0 # The time taken to turn off the node
# service_patching_duration: int # The time taken to patch a service
# file_system_repairing_limit: int # The time take to repair a file
# file_system_restoring_limit: int # The time take to restore a file
# file_system_scanning_limit: int # The time taken to scan the file system
# # Patching / Reset
#


@@ -1,7 +1,3 @@
- item_type: ACTIONS
type: NODE
- item_type: STEPS
steps: 128
- item_type: PORTS
ports_list:
- port: '80'


@@ -1,7 +1,3 @@
- item_type: ACTIONS
type: NODE
- item_type: STEPS
steps: 128
- item_type: PORTS
ports_list:
- port: '80'


@@ -1,7 +1,3 @@
- item_type: ACTIONS
type: NODE
- item_type: STEPS
steps: 256
- item_type: PORTS
ports_list:
- port: '80'


@@ -1,16 +1,42 @@
# Main Config File
# Training Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# RED AGENT IDENTIFIER
# RANDOM or NONE
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# "TF2" (Tensorflow 2.X)
# "TORCH" (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -25,21 +51,34 @@ observation_space:
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Sets how often the agent will save a checkpoint (every n episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0


@@ -7,8 +7,10 @@
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# RED AGENT IDENTIFIER
# RANDOM or NONE
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True
# Sets How the Action Space is defined:


@@ -1,14 +1,57 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
from typing import Any, Dict, Final, Union
from primaite import USERS_CONFIG_DIR, getLogger
import yaml
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a legacy lay down config dict to the new format.
:param legacy_config_dict: A legacy lay down config dict.
"""
_LOGGER.warning("Legacy lay down config conversion not yet implemented")
return legacy_config_dict
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
"""
Read in a lay down config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise
False.
:return: The lay down config as a dict.
:raises ValueError: If the file_path does not exist.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading lay down config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_lay_down_config_dict(config)
except KeyError:
msg = (
f"Failed to convert lay down config file {file_path} "
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
return config
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.


@@ -1,58 +1,96 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType
from primaite import getLogger, USERS_CONFIG_DIR
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
HardCodedAgentView,
SB3OutputVerboseLevel,
SessionType,
)
_LOGGER = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
@dataclass()
class TrainingConfig:
"""The Training Config class."""
# Generic
agent_identifier: str = "STABLE_BASELINES3_A2C"
"The Red Agent algo/class to be used."
agent_framework: AgentFramework = AgentFramework.SB3
"The AgentFramework"
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
"The DeepLearningFramework"
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
"The AgentIdentifier"
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
"The view the deterministic hard-coded agent has of the environment"
random_red_agent: bool = False
"Creates Random Red Agent Attacks"
action_type: ActionType = ActionType.ANY
"The ActionType to use."
"The ActionType to use"
num_episodes: int = 10
"The number of episodes to train over."
"The number of episodes to train over"
num_steps: int = 256
"The number of steps in an episode."
observation_space: dict = field(
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
)
"The observation space config dict."
"The number of steps in an episode"
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes"
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
"The observation space config dict"
time_delay: int = 10
"The delay between steps (ms). Applies to generic agents only."
"The delay between steps (ms). Applies to generic agents only"
# file
session_type: str = "TRAINING"
"the session type to run (TRAINING or EVALUATION)"
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: bool = False
"Determine whether to load an agent from file."
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
"File path and file name of agent if you're loading one in."
"File path and file name of agent if you're loading one in"
# Environment
observation_space_high_value: int = 1000000000
"The high value for the observation space."
"The high value for the observation space"
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
"Stable Baselines3 learn/eval output verbosity level"
# Reward values
# Generic
@@ -117,28 +155,51 @@ class TrainingConfig:
# Patching / Reset durations
os_patching_duration: int = 5
"The time taken to patch the OS."
"The time taken to patch the OS"
node_reset_duration: int = 5
"The time taken to reset a node (hardware)."
"The time taken to reset a node (hardware)"
node_booting_duration: int = 3
"The time taken to turn on the node."
"The time taken to turn on the node"
node_shutdown_duration: int = 2
"The time taken to turn off the node."
"The time taken to turn off the node"
service_patching_duration: int = 5
"The time taken to patch a service."
"The time taken to patch a service"
file_system_repairing_limit: int = 5
"The time taken to repair the file system."
"The time taken to repair the file system"
file_system_restoring_limit: int = 5
"The time taken to restore the file system."
"The time taken to restore the file system"
file_system_scanning_limit: int = 5
"The time taken to scan the file system."
"The time taken to scan the file system"
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
:param config_dict: The training config dict.
:return: The instance of TrainingConfig.
"""
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
"agent_identifier": AgentIdentifier,
"action_type": ActionType,
"session_type": SessionType,
"sb3_output_verbose_level": SB3OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView,
}
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict)
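The name-to-enum conversion in `from_dict` can be tried on a cut-down config. The sketch below uses a hypothetical `MiniTrainingConfig` with three fields; the two enums are copied from this PR. Unlike the method above, it copies the input dict first so the caller's dict is left untouched, which is a deliberate deviation for the example.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


class AgentFramework(Enum):
    CUSTOM = 0
    SB3 = 1
    RLLIB = 2


class SessionType(Enum):
    TRAIN = 1
    EVAL = 2
    TRAIN_EVAL = 3


@dataclass
class MiniTrainingConfig:
    """Hypothetical three-field stand-in for TrainingConfig."""

    agent_framework: AgentFramework = AgentFramework.SB3
    session_type: SessionType = SessionType.TRAIN
    num_episodes: int = 10

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "MiniTrainingConfig":
        field_enum_map = {"agent_framework": AgentFramework, "session_type": SessionType}
        converted = dict(config_dict)  # shallow copy; the PR's method mutates its input
        for key, enum_cls in field_enum_map.items():
            if key in converted:
                # Lookup by member name; an unknown name raises KeyError.
                converted[key] = enum_cls[converted[key]]
        return cls(**converted)


raw = {"agent_framework": "RLLIB", "session_type": "TRAIN_EVAL"}
config = MiniTrainingConfig.from_dict(raw)
print(config.agent_framework.name)  # RLLIB
print(config.num_episodes)  # 10
print(raw["agent_framework"])  # RLLIB
```

Note that a misspelled enum name in the YAML surfaces as a `KeyError` from the lookup, whereas an unexpected field name surfaces as a `TypeError` from the dataclass constructor.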
def to_dict(self, json_serializable: bool = True):
"""
@@ -150,24 +211,28 @@ class TrainingConfig:
"""
data = dict(self.__dict__)  # shallow copy so the instance's enum fields are not overwritten
if json_serializable:
data["action_type"] = self.action_type.value
data["agent_framework"] = self.agent_framework.name
data["deep_learning_framework"] = self.deep_learning_framework.name
data["agent_identifier"] = self.agent_identifier.name
data["action_type"] = self.action_type.name
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
data["session_type"] = self.session_type.name
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
return data
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def __str__(self) -> str:
tc = f"{self.agent_framework}, "
if self.agent_framework is AgentFramework.RLLIB:
tc += f"{self.deep_learning_framework}, "
tc += f"{self.agent_identifier}, "
if self.agent_identifier is AgentIdentifier.HARDCODED:
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={self.observation_space}, "
tc += f"{self.num_episodes} episodes @ "
tc += f"{self.num_steps} steps"
return tc
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
@@ -198,15 +263,10 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
# Convert values to Enums
config["action_type"] = ActionType[config["action_type"]]
try:
return TrainingConfig(**config)
return TrainingConfig.from_dict(config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "
f"from the training config file {file_path}"
)
msg = f"Error when creating an instance of {TrainingConfig} from the training config file {file_path}"
_LOGGER.critical(msg, exc_info=True)
raise e
msg = f"Cannot load the training config as it does not exist: {file_path}"
@@ -215,19 +275,35 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256,
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param agent_identifier: The red agent identifier to use as legacy
training configs don't have agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:return: The converted training config dict.
"""
config_dict = {"num_steps": num_steps, "action_type": action_type}
config_dict = {
"agent_framework": agent_framework.name,
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps,
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:
@@ -243,7 +319,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
:return: The mapped key.
"""
key_mapping = {
"agentIdentifier": "agent_identifier",
"agentIdentifier": None,
"numEpisodes": "num_episodes",
"timeDelay": "time_delay",
"configFilename": None,


@@ -0,0 +1,13 @@
from enum import Enum
class PlotlyTemplate(Enum):
"""The built-in plotly templates."""
PLOTLY = "plotly"
PLOTLY_WHITE = "plotly_white"
PLOTLY_DARK = "plotly_dark"
GGPLOT2 = "ggplot2"
SEABORN = "seaborn"
SIMPLE_WHITE = "simple_white"
NONE = "none"


@@ -0,0 +1,73 @@
from pathlib import Path
from typing import Dict, Optional, Union
import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure
from primaite import _PLATFORM_DIRS
def _get_plotly_config() -> Dict:
"""Get the plotly config from primaite_config.yaml."""
user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
with open(user_config_path, "r") as file:
primaite_config = yaml.safe_load(file)
return primaite_config["session"]["outputs"]["plots"]
def plot_av_reward_per_episode(
av_reward_per_episode_csv: Union[str, Path],
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> Figure:
"""
Plot the average reward per episode from a csv session output.
:param av_reward_per_episode_csv: The average reward per episode csv
file path.
:param title: The plot title. This is optional.
:param subtitle: The plot subtitle. This is optional.
:return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
"""
df = pl.read_csv(av_reward_per_episode_csv)
if title:
if subtitle:
title = f"{title} <br><sup>{subtitle}</sup>"
else:
if subtitle:
title = subtitle
config = _get_plotly_config()
layout = go.Layout(
autosize=config["size"]["auto_size"],
width=config["size"]["width"],
height=config["size"]["height"],
)
# Create the line graph with a colored line
fig = go.Figure(layout=layout)
fig.update_layout(template=config["template"])
fig.add_trace(
go.Scatter(
x=df["Episode"],
y=df["Average Reward"],
mode="lines",
name="Mean Reward per Episode",
)
)
# Set the layout of the graph
fig.update_layout(
xaxis={
"title": "Episode",
"type": "linear",
"rangeslider": {"visible": config["range_slider"]},
},
yaxis={"title": "Average Reward"},
title=title,
showlegend=False,
)
return fig
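
The title/subtitle branching in `plot_av_reward_per_episode` can be sketched in isolation as below. This assumes the subtitle is meant to be wrapped in a `<sup>` tag pair (plotly titles accept a small subset of HTML); the function name `build_title` is hypothetical and not part of the PR.

```python
from typing import Optional


def build_title(title: Optional[str] = None, subtitle: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper mirroring the title/subtitle branching in the diff."""
    if title:
        if subtitle:
            # Title on the first line, subtitle rendered smaller underneath
            return f"{title} <br><sup>{subtitle}</sup>"
        return title
    # No title: fall back to the subtitle if there is one
    return subtitle if subtitle else title
```

Pulling the branching into a small pure function like this would make it testable without building a figure or reading `primaite_config.yaml`.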


@@ -1,7 +1,7 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
from gym import spaces
@@ -63,7 +63,7 @@ class NodeLinkTable(AbstractObservationComponent):
"""
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000
_MAX_VAL: int = 1_000_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
@@ -101,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent):
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
self.current_observation[item_index][3] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
@@ -111,9 +109,7 @@ class NodeLinkTable(AbstractObservationComponent):
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
self.current_observation[item_index][service_index] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
@@ -133,9 +129,7 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
@@ -244,7 +238,12 @@ class NodeStatuses(AbstractObservationComponent):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
[
hardware_state,
software_state,
file_system_state,
*service_states,
]
)
self.current_observation[:] = obs
@@ -267,9 +266,7 @@ class NodeStatuses(AbstractObservationComponent):
for service in services:
structure.append(f"node_{node_id}_service_{service}_state_NONE")
for state in SoftwareState:
structure.append(
f"node_{node_id}_service_{service}_state_{state.name}"
)
structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
return structure
@@ -325,9 +322,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
)
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
@@ -356,9 +351,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 2)
) + 1
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
obs.append(int(traffic_level))
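
A worked sketch of the traffic-level quantisation used in `LinkTrafficLevels`. The `load >= bandwidth` branch and the floor-division formula come from the diff; the `load <= 0` branch and the minimum-levels guard are assumptions, since those lines are not shown in this hunk. The function name `quantise_traffic` is hypothetical.

```python
def quantise_traffic(load: float, bandwidth: float, quantisation_levels: int) -> int:
    """Map a link load onto an integer level in [0, quantisation_levels - 1]."""
    if quantisation_levels < 3:
        # Assumption: the formula divides by (levels - 2), so at least 3 levels are needed
        raise ValueError("quantisation_levels must be at least 3")
    if load <= 0:
        # Assumption: level 0 is reserved for an idle link
        return 0
    if load >= bandwidth:
        # Top level is reserved for a saturated link
        return quantisation_levels - 1
    # The remaining (levels - 2) bands split the open interval (0, bandwidth) evenly
    return int((load / bandwidth) // (1 / (quantisation_levels - 2)) + 1)
```

With 5 levels and a bandwidth of 100, a load of 50 falls in the second of three partial-load bands, so it maps to level 2, while loads of 0 and 100 map to the reserved levels 0 and 4.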


@@ -1,13 +1,11 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import csv
import logging
import uuid as uuid
from datetime import datetime
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Dict, Tuple, Union
from typing import Dict, Final, Tuple, Union
import networkx as nx
import numpy as np
@@ -15,11 +13,13 @@ import yaml
from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
AgentFramework,
FileSystemState,
HardwareState,
NodePOLInitiator,
@@ -27,6 +27,7 @@ from primaite.common.enums import (
NodeType,
ObservationType,
Priority,
SessionType,
SoftwareState,
)
from primaite.common.service import Service
@@ -45,9 +46,9 @@ from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
_LOGGER = getLogger(__name__)
class Primaite(Env):
@@ -63,7 +64,6 @@ class Primaite(Env):
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
):
@@ -72,26 +72,23 @@ class Primaite(Env):
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param transaction_list: The list of transactions to populate.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
self.session_path: Final[Path] = session_path
self.timestamp_str: Final[str] = timestamp_str
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self.training_config: TrainingConfig = training_config.load(
training_config_path
)
self.training_config: TrainingConfig = training_config.load(training_config_path)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps = self.training_config.num_steps
super(Primaite, self).__init__()
# Transaction list
self.transaction_list = transaction_list
# The agent in use
self.agent_identifier = self.training_config.agent_identifier
@@ -178,6 +175,9 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
self._obs_space_description = None
"The env observation space description for transactions writing"
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -204,16 +204,14 @@ class Primaite(Env):
plt.savefig(file_path, format="PNG")
plt.clf()
except Exception:
_LOGGER.error("Could not save network diagram")
_LOGGER.error("Exception occurred", exc_info=True)
print("Could not save network diagram")
_LOGGER.error("Could not save network diagram", exc_info=True)
# Initiate observation space
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
if self.training_config.action_type == ActionType.NODE:
_LOGGER.info("Action space type NODE selected")
_LOGGER.debug("Action space type NODE selected")
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
@@ -222,7 +220,7 @@ class Primaite(Env):
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.training_config.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
_LOGGER.debug("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
@@ -233,27 +231,29 @@ class Primaite(Env):
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.training_config.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
_LOGGER.debug("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
file_name = f"average_reward_per_episode_{timestamp_str}.csv"
file_path = session_path / file_name
self.csv_file = open(file_path, "w", encoding="UTF8", newline="")
self.csv_writer = csv.writer(self.csv_file)
self.csv_writer.writerow(header)
except Exception:
_LOGGER.error(
"Could not create csv file to hold average reward per episode"
)
_LOGGER.error("Exception occurred", exc_info=True)
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
@property
def actual_episode_count(self) -> int:
"""Shifts the episode_count by -1 for RLlib."""
if self.training_config.agent_framework is AgentFramework.RLLIB:
return self.episode_count - 1
return self.episode_count
def set_as_eval(self):
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
self.episode_count = 0
self.step_count = 0
self.total_step_count = 0
def reset(self):
"""
@@ -262,12 +262,14 @@ class Primaite(Env):
Returns:
Environment observation space (reset)
"""
csv_data = self.episode_count, self.average_reward
self.csv_writer.writerow(csv_data)
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
self.episode_count += 1
# Don't need to reset links, as they are cleared and recalculated every step
# Don't need to reset links, as they are cleared and recalculated every
# step
# Clear the ACL
self.init_acl()
@@ -287,6 +289,7 @@ class Primaite(Env):
# Update observations space and return
self.update_environent_obs()
return self.env_obs
def step(self, action):
@@ -302,15 +305,10 @@ class Primaite(Env):
done: Indicates episode is complete if True
step_info: Additional information relating to this step
"""
if self.step_count == 0:
print(f"Episode: {str(self.episode_count)}")
# TEMP
done = False
self.step_count += 1
self.total_step_count += 1
# print("Episode step: " + str(self.step_count))
# Need to clear traffic on all links first
for link_key, link_value in self.links.items():
@@ -320,14 +318,15 @@ class Primaite(Env):
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
)
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
# Load the initial observation space into the transaction
transaction.set_obs_space(self.obs_handler._flat_observation)
transaction.obs_space = self.obs_handler._flat_observation
# Set the transaction obs space description
transaction.obs_space_description = self._obs_space_description
# Load the action space into the transaction
transaction.set_action_space(copy.deepcopy(action))
transaction.action_space = copy.deepcopy(action)
# 1. Implement Blue Action
self.interpret_action_and_apply(action)
@@ -371,9 +370,7 @@ class Primaite(Env):
self.acl,
self.step_count,
)
apply_red_agent_node_pol(
self.nodes, self.red_iers, self.red_node_pol, self.step_count
)
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
# Take snapshots of nodes and links
self.nodes_post_red = copy.deepcopy(self.nodes)
self.links_post_red = copy.deepcopy(self.links)
@@ -389,17 +386,17 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
# print(f" Step {self.step_count} Reward: {str(reward)}")
_LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
if self.training_config.session_type == "EVALUATION":
if self.training_config.session_type is SessionType.EVAL:
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
print(f" Average Reward: {str(self.average_reward)}")
_LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}")
# Load the reward into the transaction
transaction.set_reward(reward)
transaction.reward = reward
# 6. Output Verbose
# self.output_link_status()
@@ -407,15 +404,21 @@ class Primaite(Env):
# 7. Update env_obs
self.update_environent_obs()
# 8. Add the transaction to the list of transactions
self.transaction_list.append(copy.deepcopy(transaction))
# Write transaction to file
if self.actual_episode_count > 0:
self.transaction_writer.write(transaction)
# Return
return self.env_obs, reward, done, self.step_info
def __close__(self):
"""Override close function."""
self.csv_file.close()
def close(self):
"""Override parent close and close writers."""
# Close files if last episode/step
# if self.can_finish:
super().close()
self.transaction_writer.close()
self.episode_av_reward_writer.close()
def init_acl(self):
"""Initialise the Access Control List."""
@@ -424,14 +427,9 @@ class Primaite(Env):
def output_link_status(self):
"""Output the link status of all links to the console."""
for link_key, link_value in self.links.items():
print("Link ID: " + link_value.get_id())
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
print(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
+ str(protocol.get_load())
)
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
"""
@@ -446,13 +444,9 @@ class Primaite(Env):
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multidiscrete (array) form have len 4
elif len(self.action_dict[_action]) == 4: # Node actions in multidiscrete (array) form have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
@@ -518,9 +512,7 @@ class Primaite(Env):
return
elif property_action == 1:
# Patch (valid action if it's good or compromised)
node.set_service_state(
self.services_list[service_index], SoftwareState.PATCHING
)
node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING)
else:
# Node is not of Service Type
return
@@ -675,6 +667,9 @@ class Primaite(Env):
"""
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
if not self._obs_space_description:
self._obs_space_description = self.obs_handler.describe_structure()
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
@@ -1225,11 +1220,7 @@ class Primaite(Env):
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
@@ -1244,9 +1235,7 @@ class Primaite(Env):
# Decide how many nodes become compromised
node_list = list(self.nodes.values())
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
max_num_nodes_compromised = len(
computers
) # only computers can become compromised
max_num_nodes_compromised = len(computers) # only computers can become compromised
# random select between 1 and max_num_nodes_compromised
num_nodes_to_compromise = randint(1, max_num_nodes_compromised)
@@ -1257,9 +1246,7 @@ class Primaite(Env):
source_node = choice(nodes_to_be_compromised)
# For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = (
self.episode_steps // 2
) # always compromise in first half of episode
max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode
# Bandwidth for all links
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
@@ -1308,9 +1295,7 @@ class Primaite(Env):
ier_protocol = pol_service_name # Same protocol as compromised node
ier_service = node.services[pol_service_name]
ier_port = ier_service.port
ier_mission_criticality = (
0 # Red IER will never be important to green agent success
)
ier_mission_criticality = 0 # Red IER will never be important to green agent success
# We choose a node to attack based on the first that applies:
# a. Green IERs, select dest node of the red ier based on dest node of green IER
# b. Attack a random server that doesn't have a DENY acl rule in default config
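
The interaction between `actual_episode_count` and the `> 0` guard in `reset()` can be sketched with a stripped-down stand-in. The class `EpisodeCounter`, the in-memory `rows` list, and the enum values are hypothetical; the real environment writes through `SessionOutputWriter`. The PR does not say why RLlib needs the shift; presumably it triggers one extra `reset()` before the first real episode.

```python
from enum import Enum
from typing import List, Tuple


class AgentFramework(Enum):
    """Stand-in for primaite.common.enums.AgentFramework (values are assumptions)."""

    SB3 = 1
    RLLIB = 2


class EpisodeCounter:
    """Hypothetical stand-in for the episode bookkeeping in Primaite.reset()."""

    def __init__(self, framework: AgentFramework):
        self.framework = framework
        self.episode_count = 0
        self.average_reward = 0.0
        self.rows: List[Tuple[int, float]] = []  # stands in for the csv writer

    @property
    def actual_episode_count(self) -> int:
        """Shift the episode count by -1 for RLlib, as in the diff."""
        if self.framework is AgentFramework.RLLIB:
            return self.episode_count - 1
        return self.episode_count

    def reset(self) -> None:
        # Only record a row once at least one real episode has completed
        if self.actual_episode_count > 0:
            self.rows.append((self.actual_episode_count, self.average_reward))
        self.episode_count += 1
```

Under this sketch an SB3 run records its first row on the second `reset()`, while an RLlib run records it on the third, so both frameworks number their first recorded episode as 1.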


@@ -41,27 +41,19 @@ def calculate_reward_function(
reference_node = reference_nodes[node_key]
# Hardware State
reward_value += score_node_operating_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
# Software State
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
reward_value += score_node_os_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
# Service State
if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
# File System State
if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
# Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items():
@@ -80,14 +72,9 @@ def calculate_reward_function(
if step_count >= start_step and step_count <= stop_step:
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked * ier_value.get_mission_criticality()
)
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
if live_blocked and not reference_blocked:
_LOGGER.debug(
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
)
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(


@@ -1,338 +1,29 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""
The main PrimAITE session runner module.
TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_dir and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class.
"""
"""The main PrimAITE session runner module."""
import argparse
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Final, Union
from uuid import uuid4
from typing import Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import write_transaction_to_file
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run_generic(env: Primaite, config_values: TrainingConfig):
"""
Run against a generic agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
"""
for episode in range(0, config_values.num_episodes):
env.reset()
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = env.action_space.sample()
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Introduce a delay between steps
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
env.close()
def run_stable_baselines3_ppo(
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
"""
Run against a stable_baselines3 PPO agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = PPO.load(
config_values.agent_load_file,
env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occurred", exc_info=True)
else:
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(config_values.num_episodes):
agent.learn(total_timesteps=config_values.num_steps)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close()
def run_stable_baselines3_a2c(
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
"""
Run against a stable_baselines3 A2C agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = A2C.load(
config_values.agent_load_file,
env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occurred", exc_info=True)
else:
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(config_values.num_episodes):
agent.learn(total_timesteps=config_values.num_steps)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close()
def _write_session_metadata_file(
session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
):
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_dir`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": uuid,
"start_datetime": session_timestamp.isoformat(),
"end_datetime": None,
"total_episodes": None,
"total_time_steps": None,
"env": {
"training_config": env.training_config.to_dict(json_serializable=True),
"lay_down_config": env.lay_down_config,
},
}
filepath = session_dir / "session_metadata.json"
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
def _update_session_metadata_file(session_dir: Path, env: Primaite):
"""
Update the ``session_metadata.json`` file.
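Updates the ``session_metadata.json`` in the ``session_dir`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of episodes run in the session.
- total_time_steps: The total number of time steps run in the session.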
"""
with open(session_dir / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = env.episode_count
metadata_dict["total_time_steps"] = env.total_step_count
filepath = session_dir / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
"""
Persist an agent.
Only works for stable baselines3 agents at present.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if not isinstance(agent, OnPolicyAlgorithm):
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
_LOGGER.error(msg)
else:
filepath = session_path / f"agent_saved_{timestamp_str}"
agent.save(filepath)
_LOGGER.debug(f"Trained agent saved as: {filepath}")
def _get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_dir
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
"""Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
"""
# Welcome message
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
uuid = str(uuid4())
session_timestamp: Final[datetime] = datetime.now()
session_dir = _get_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session = PrimaiteSession(training_config_path, lay_down_config_path)
print(f"The output directory for this session is: {session_dir}")
# Create a list of transactions
# A transaction is an object holding the:
# - episode #
# - step #
# - initial observation space
# - action
# - reward
# - new observation space
transaction_list = []
# Create the Primaite environment
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Writing Session Metadata file...")
_write_session_metadata_file(
session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env
)
config_values = env.training_config
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps
# Run environment against an agent
if config_values.agent_identifier == "GENERIC":
run_generic(env=env, config_values=config_values)
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
run_stable_baselines3_ppo(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Session finished")
_LOGGER.debug("Session finished")
print("Saving transaction logs...")
write_transaction_to_file(
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
obs_space_description=env.obs_handler.describe_structure(),
)
print("Updating Session Metadata file...")
_update_session_metadata_file(session_dir=session_dir, env=env)
print("Finished")
_LOGGER.debug("Finished")
session.setup()
session.learn()
session.evaluate()
if __name__ == "__main__":
@@ -341,11 +32,7 @@ if __name__ == "__main__":
parser.add_argument("--ldc")
args = parser.parse_args()
if not args.tc:
_LOGGER.error(
"Please provide a training config file using the --tc " "argument"
)
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error(
"Please provide a lay down config file using the --ldc " "argument"
)
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
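
As written, the `__main__` block logs an error when `--tc` or `--ldc` is missing but still goes on to call `run`. One possible alternative, sketched here as a suggestion rather than as what the PR does, is to let `argparse` enforce both arguments. The function name `parse_args` and the help strings are hypothetical.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical helper: parse the two config paths, failing fast if either is missing."""
    parser = argparse.ArgumentParser(description="Run a PrimAITE session.")
    # required=True makes argparse print usage and exit with status 2 when omitted
    parser.add_argument("--tc", required=True, help="Training config file path")
    parser.add_argument("--ldc", required=True, help="Lay down config file path")
    return parser.parse_args(argv)
```

With this approach the session would never start with a `None` config path, and the usage message would come from `argparse` rather than from hand-written log calls.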


@@ -3,13 +3,7 @@
import logging
from typing import Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
@@ -44,9 +38,7 @@ class ActiveNode(Node):
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(
node_id, name, node_type, priority, hardware_state, config_values
)
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
self.ip_address: str = ip_address
# Related to Software
self._software_state: SoftwareState = software_state
@@ -125,14 +117,10 @@ class ActiveNode(Node):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
@@ -145,9 +133,7 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(
self, file_system_state: FileSystemState
):
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
"""
Sets the file system state (actual and observed) if not in a compromised state.
@@ -164,14 +150,10 @@ class ActiveNode(Node):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD

@@ -28,9 +28,7 @@ class PassiveNode(Node):
:param config_values: Config values.
"""
# Pass through to Super for now
super().__init__(
node_id, name, node_type, priority, hardware_state, config_values
)
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
@property
def ip_address(self) -> str:

@@ -3,13 +3,7 @@
import logging
from typing import Dict, Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -128,9 +122,7 @@ class ServiceNode(ActiveNode):
) or software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "
@@ -141,9 +133,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -159,9 +149,7 @@ class ServiceNode(ActiveNode):
if service_value.software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "

@@ -4,7 +4,7 @@ import os
import subprocess
import sys
from primaite import NOTEBOOKS_DIR, getLogger
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)

@@ -86,9 +86,7 @@ def apply_iers(
and source_node.software_state != SoftwareState.PATCHING
):
if source_node.has_service(protocol):
if source_node.service_running(
protocol
) and not source_node.service_is_overwhelmed(protocol):
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
source_valid = True
else:
source_valid = False
@@ -103,10 +101,7 @@ def apply_iers(
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if (
dest_node.hardware_state == HardwareState.ON
and dest_node.software_state != SoftwareState.PATCHING
):
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
dest_valid = True
else:
# IER no longer valid
@@ -116,14 +111,9 @@ def apply_iers(
pass
else:
# It's not a switch or an actuator (so active node)
if (
dest_node.hardware_state == HardwareState.ON
and dest_node.software_state != SoftwareState.PATCHING
):
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
if dest_node.has_service(protocol):
if dest_node.service_running(
protocol
) and not dest_node.service_is_overwhelmed(protocol):
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
dest_valid = True
else:
dest_valid = False
@@ -136,9 +126,7 @@ def apply_iers(
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(
source_node.ip_address, dest_node.ip_address, protocol, port
)
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
@@ -169,10 +157,7 @@ def apply_iers(
# We might have a switch in the path, so check all nodes are operational
for node in path_node_list:
if (
node.hardware_state != HardwareState.ON
or node.software_state == SoftwareState.PATCHING
):
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
path_valid = False
if path_valid:
@@ -184,9 +169,7 @@ def apply_iers(
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
)
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
@@ -204,7 +187,8 @@ def apply_iers(
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]

@@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
NodePOLInitiator,
NodePOLType,
NodeType,
SoftwareState,
)
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
@@ -83,10 +77,7 @@ def apply_red_agent_iers(
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state
if (
source_node.get_service_state(protocol)
== SoftwareState.COMPROMISED
):
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
source_valid = True
else:
source_valid = False
@@ -124,9 +115,7 @@ def apply_red_agent_iers(
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(
source_node.ip_address, dest_node.ip_address, protocol, port
)
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
@@ -170,9 +159,7 @@ def apply_red_agent_iers(
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
)
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
@@ -190,7 +177,8 @@ def apply_red_agent_iers(
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
@@ -248,9 +236,7 @@ def apply_red_agent_node_pol(
state = node_instruction.get_state()
source_node_id = node_instruction.get_source_node_id()
source_node_service_name = node_instruction.get_source_node_service()
source_node_service_state_value = (
node_instruction.get_source_node_service_state()
)
source_node_service_state_value = node_instruction.get_source_node_service_state()
passed_checks = False
@@ -292,9 +278,7 @@ def apply_red_agent_node_pol(
target_node.hardware_state = state
elif pol_type == NodePOLType.OS:
# Change OS state
if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.software_state = state
elif pol_type == NodePOLType.SERVICE:
# Change a service state
@@ -302,9 +286,7 @@ def apply_red_agent_node_pol(
target_node.set_service_state(service_name, state)
else:
# Change the file system status
if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
if _VERBOSE:

@@ -0,0 +1,150 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
_LOGGER = getLogger(__name__)
class PrimaiteSession:
"""
The PrimaiteSession class.
Provides a single learning and evaluation entry point for all training
and lay down configurations.
"""
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Path] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Path] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self):
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ACL:
# Do-nothing Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ANY:
# Do-nothing Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
else:
# Invalid AgentFramework AgentIdentifier combo
raise ValueError
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
else:
# Invalid AgentFramework
raise ValueError
self.session_path: Path = self._agent_session.session_path
self.timestamp_str: str = self._agent_session.timestamp_str
self.learning_path: Path = self._agent_session.learning_path
self.evaluation_path: Path = self._agent_session.evaluation_path
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.EVAL:
self._agent_session.learn(**kwargs)
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(**kwargs)
def close(self):
"""Closes the agent."""
self._agent_session.close()
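
The branching in `setup()` amounts to a lookup from (agent_framework, agent_identifier, action_type) to an agent session class. A minimal sketch of that mapping as a table is below. The enums are stand-ins for those in `primaite.common.enums`, class names are returned as strings rather than real classes, and `resolve_agent` is a hypothetical helper that is not part of this PR.

```python
from enum import Enum
from typing import Dict, Optional, Tuple


class AgentFramework(Enum):  # stand-in for primaite.common.enums.AgentFramework
    CUSTOM = "CUSTOM"
    SB3 = "SB3"
    RLLIB = "RLLIB"


class AgentIdentifier(Enum):  # stand-in for primaite.common.enums.AgentIdentifier
    A2C = "A2C"
    PPO = "PPO"
    HARDCODED = "HARDCODED"
    DO_NOTHING = "DO_NOTHING"
    RANDOM = "RANDOM"
    DUMMY = "DUMMY"


class ActionType(Enum):  # stand-in for primaite.common.enums.ActionType
    NODE = "NODE"
    ACL = "ACL"
    ANY = "ANY"


# Custom agents keyed by (identifier, action type); None means "any action type".
_CUSTOM_AGENTS: Dict[Tuple[AgentIdentifier, Optional[ActionType]], str] = {
    (AgentIdentifier.HARDCODED, ActionType.NODE): "HardCodedNodeAgent",
    (AgentIdentifier.HARDCODED, ActionType.ACL): "HardCodedACLAgent",
    (AgentIdentifier.DO_NOTHING, ActionType.NODE): "DoNothingNodeAgent",
    (AgentIdentifier.DO_NOTHING, ActionType.ACL): "DoNothingACLAgent",
    (AgentIdentifier.RANDOM, None): "RandomAgent",
    (AgentIdentifier.DUMMY, None): "DummyAgent",
}


def resolve_agent(
    framework: AgentFramework,
    identifier: AgentIdentifier,
    action_type: ActionType,
) -> str:
    """Return the name of the agent session class the config selects."""
    if framework is AgentFramework.SB3:
        return "SB3Agent"
    if framework is AgentFramework.RLLIB:
        return "RLlibAgent"
    # From here on the framework is CUSTOM.
    if identifier in (AgentIdentifier.RANDOM, AgentIdentifier.DUMMY):
        return _CUSTOM_AGENTS[(identifier, None)]
    scripted = (AgentIdentifier.HARDCODED, AgentIdentifier.DO_NOTHING)
    if identifier in scripted and action_type is ActionType.ANY:
        raise NotImplementedError(f"{identifier.name} agents do not support ActionType.ANY")
    try:
        return _CUSTOM_AGENTS[(identifier, action_type)]
    except KeyError:
        raise ValueError(
            f"Invalid combination: {framework.name}, {identifier.name}, {action_type.name}"
        ) from None
```

Unlike the bare `raise ValueError` in `setup()`, the sketch attaches the offending combination to the exception message, which may make misconfigured training configs easier to diagnose.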

@@ -1,5 +1,22 @@
# The main PrimAITE application config file
# Logging
log_level: INFO
logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
logging:
log_level: INFO
logger_format:
DEBUG: '%(asctime)s: %(message)s'
INFO: '%(asctime)s: %(message)s'
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
# Session
session:
outputs:
plots:
size:
auto_size: false
width: 1500
height: 900
template: plotly_white
range_slider: false
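
The new `logger_format` block maps each log level to its own format string. One way such a mapping could be applied with the standard `logging` module is sketched below. `PerLevelFormatter` is a hypothetical name and this is not necessarily how PrimAITE consumes the config.

```python
import logging
from typing import Dict


class PerLevelFormatter(logging.Formatter):
    """Formats each record with the format string registered for its level."""

    def __init__(self, formats: Dict[str, str], default: str = "%(message)s"):
        # The default format is used for any level missing from the mapping.
        super().__init__(default)
        self._formatters = {
            level: logging.Formatter(fmt) for level, fmt in formats.items()
        }

    def format(self, record: logging.LogRecord) -> str:
        formatter = self._formatters.get(record.levelname)
        if formatter is None:
            return super().format(record)
        return formatter.format(record)


# Format strings copied from the logger_format section of the app config.
LOGGER_FORMAT = {
    "DEBUG": "%(asctime)s: %(message)s",
    "INFO": "%(asctime)s: %(message)s",
    "WARNING": "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
    "ERROR": "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
    "CRITICAL": "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
}

handler = logging.StreamHandler()
handler.setFormatter(PerLevelFormatter(LOGGER_FORMAT))
```

With this arrangement, routine DEBUG and INFO lines stay short while WARNING and above carry the logger name and line number.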

@@ -6,7 +6,7 @@ from pathlib import Path
import pkg_resources
from primaite import NOTEBOOKS_DIR, getLogger
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
@@ -18,9 +18,7 @@ def run(overwrite_existing: bool = True):
:param overwrite_existing: A bool to toggle replacing existing edited
notebooks on or off.
"""
notebooks_package_data_root = pkg_resources.resource_filename(
"primaite", "notebooks/_package_data"
)
notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
for subdir, dirs, files in os.walk(notebooks_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
@@ -30,9 +28,7 @@ def run(overwrite_existing: bool = True):
copy_file = not target_fp.is_file()
if overwrite_existing and not copy_file:
copy_file = (not filecmp.cmp(fp, target_fp)) and (
".ipynb_checkpoints" not in str(target_fp)
)
copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp))
if copy_file:
shutil.copy2(fp, target_fp)

@@ -5,7 +5,7 @@ from pathlib import Path
import pkg_resources
from primaite import USERS_CONFIG_DIR, getLogger
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
@@ -17,9 +17,7 @@ def run(overwrite_existing=True):
:param overwrite_existing: A bool to toggle replacing existing edited
config on or off.
"""
configs_package_data_root = pkg_resources.resource_filename(
"primaite", "config/_package_data"
)
configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")
for subdir, dirs, files in os.walk(configs_package_data_root):
for file in files:

@@ -1,5 +1,5 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)

@@ -1,48 +1,99 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Tuple
from primaite.common.enums import AgentIdentifier
class Transaction(object):
"""Transaction class."""
def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
"""
Init.
Transaction constructor.
Args:
_timestamp: The time this object was created
_agent_identifier: An identifier for the agent in use
_episode_number: The episode number
_step_number: The step number
:param agent_identifier: An identifier for the agent in use
:param episode_number: The episode number
:param step_number: The step number
"""
self.timestamp = _timestamp
self.agent_identifier = _agent_identifier
self.episode_number = _episode_number
self.step_number = _step_number
self.timestamp = datetime.now()
"The datetime of the transaction"
self.agent_identifier: AgentIdentifier = agent_identifier
"The agent identifier"
self.episode_number: int = episode_number
"The episode number"
self.step_number: int = step_number
"The step number"
self.obs_space = None
"The observation space (pre)"
self.obs_space_pre = None
"The observation space before any actions are taken"
self.obs_space_post = None
"The observation space after any actions are taken"
self.reward = None
"The reward value"
self.action_space = None
"The action space invoked by the agent"
self.obs_space_description = None
"The env observation space description"
def set_obs_space(self, _obs_space):
def as_csv_data(self) -> Tuple[List, List]:
"""
Sets the observation space (pre).
Converts the Transaction to a csv data row and provides a header.
Args:
_obs_space_pre: The observation space before any actions are taken
:return: A tuple consisting of (header, data).
"""
self.obs_space = _obs_space
if isinstance(self.action_space, int):
action_length = self.action_space
else:
action_length = self.action_space.size
def set_reward(self, _reward):
"""
Sets the reward.
# Create the action space headers array
action_header = []
for x in range(action_length):
action_header.append("AS_" + str(x))
Args:
_reward: The reward value
"""
self.reward = _reward
# Build the csv header
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + self.obs_space_description
def set_action_space(self, _action_space):
"""
Sets the action space.
row = [
str(self.timestamp),
str(self.episode_number),
str(self.step_number),
str(self.reward),
]
row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist()
return header, row
Args:
_action_space: The action space invoked by the agent
"""
self.action_space = _action_space
def _turn_action_space_to_array(action_space) -> List[str]:
"""
Turns action space into a string array so it can be saved to csv.
:param action_space: The action space
:return: The action space as an array of strings
"""
if isinstance(action_space, list):
return [str(i) for i in action_space]
else:
return [str(action_space)]
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
"""
Turns observation space into a string array so it can be saved to csv.
:param obs_space: The observation space
:param obs_assets: The number of assets (i.e. nodes or links) in the
observation space
:param obs_features: The number of features associated with the asset
:return: The observation space as an array of strings
"""
return_array = []
for x in range(obs_assets):
for y in range(obs_features):
return_array.append(str(obs_space[x][y]))
return return_array
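
To make the (header, row) pairing returned by `as_csv_data` concrete, here is a minimal standalone sketch. It is a simplification rather than the PR's code: `csv_header_and_row` is a hypothetical function, plain lists stand in for the numpy observation array, and a scalar action is assumed to occupy a single `AS_0` column.

```python
from datetime import datetime
from typing import List, Sequence, Tuple, Union


def csv_header_and_row(
    timestamp: datetime,
    episode: int,
    step: int,
    reward: float,
    action: Union[int, Sequence[int]],
    obs: Sequence[int],
    obs_description: Sequence[str],
) -> Tuple[List[str], List[str]]:
    """Build a csv header and the matching data row for one step."""
    # Assumption: a scalar action is treated as a one-element action list.
    actions = [action] if isinstance(action, int) else list(action)

    header = ["Timestamp", "Episode", "Step", "Reward"]
    header += [f"AS_{i}" for i in range(len(actions))]
    header += list(obs_description)

    row = [str(timestamp), str(episode), str(step), str(reward)]
    row += [str(a) for a in actions]
    row += [str(o) for o in obs]
    return header, row
```

Deriving the `AS_n` columns from the same list that fills the row keeps the header and row lengths equal by construction, provided the observation description has one entry per observation value.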

@@ -1,91 +0,0 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Writes the Transaction log list out to file for evaluation to utilise."""
import csv
from pathlib import Path
from primaite import getLogger
_LOGGER = getLogger(__name__)
def turn_action_space_to_array(_action_space):
"""
Turns action space into a string array so it can be saved to csv.
Args:
_action_space: The action space.
"""
if isinstance(_action_space, list):
return [str(i) for i in _action_space]
else:
return [str(_action_space)]
def write_transaction_to_file(
transaction_list,
session_path: Path,
timestamp_str: str,
obs_space_description: list,
):
"""
Writes transaction logs to file to support training evaluation.
:param transaction_list: The list of transactions from all steps and all
episodes.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
# Get the first transaction and use it to determine the makeup of the
# observation space and action space
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
# space as "AS_1"
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = transaction_list[0]
action_length = template_transation.action_space.size
# obs_shape = template_transation.obs_space_post.shape
# obs_assets = template_transation.obs_space_post.shape[0]
# if len(obs_shape) == 1:
# bit of a workaround but I think the way transactions are written will change soon
# obs_features = 1
# else:
# obs_features = template_transation.obs_space_post.shape[1]
# Create the action space headers array
action_header = []
for x in range(action_length):
action_header.append("AS_" + str(x))
# Create the observation space headers array
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
# obs_header_new = [f"post_{o}" for o in obs_space_description]
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_space_description
try:
filename = session_path / f"all_transactions_{timestamp_str}.csv"
_LOGGER.debug(f"Saving transaction logs: {filename}")
csv_file = open(filename, "w", encoding="UTF8", newline="")
csv_writer = csv.writer(csv_file)
csv_writer.writerow(header)
for transaction in transaction_list:
csv_data = [
str(transaction.timestamp),
str(transaction.episode_number),
str(transaction.step_number),
str(transaction.reward),
]
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ transaction.obs_space.tolist()
)
csv_writer.writerow(csv_data)
csv_file.close()
except Exception:
_LOGGER.error("Could not save the transaction file", exc_info=True)

@@ -0,0 +1,20 @@
from pathlib import Path
from typing import Dict, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
import polars as pl
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
"""
Read an average rewards per episode csv file and return as a dict.
The dictionary keys are the episode numbers, and the values are the mean
reward for that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode csv as a dict.
"""
d = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])}
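
For readers unfamiliar with polars, the mapping that `av_rewards_dict` produces can be illustrated with a standard-library equivalent. This is a sketch for illustration only: `av_rewards_dict_stdlib` is a hypothetical name, and the column names are taken from the `Episode` and `Average Reward` header used by the session output writer.

```python
import csv
from pathlib import Path
from typing import Dict, Union


def av_rewards_dict_stdlib(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
    """Read an average-reward-per-episode csv into {episode: mean reward}."""
    with open(av_rewards_csv_file, newline="", encoding="utf8") as csv_file:
        reader = csv.DictReader(csv_file)
        return {
            int(row["Episode"]): float(row["Average Reward"]) for row in reader
        }
```

A file containing the rows `1,-0.5` and `2,0.25` under that header would yield a dict with keys 1 and 2 mapped to those two rewards.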

@@ -0,0 +1,83 @@
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
class SessionOutputWriter:
"""
A session output writer class.
Used to write session outputs to a csv file.
"""
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
"Episode",
"Average Reward",
]
def __init__(
self,
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True,
):
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
if self.learning_session:
self._csv_file_path = self._env.session_path / "learning" / fn
else:
self._csv_file_path = self._env.session_path / "evaluation" / fn
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
self._csv_file = None
self._csv_writer = None
self._first_write: bool = True
def _init_csv_writer(self):
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
self.close()
def close(self):
"""Close the csv file."""
if self._csv_file:
self._csv_file.close()
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(self, data: Union[Tuple, Transaction]):
"""
Write a row of session data.
:param data: The row of data to write. Can be a Tuple or an instance
of Transaction.
"""
if isinstance(data, Transaction):
header, data = data.as_csv_data()
else:
header = self._AV_REWARD_PER_EPISODE_HEADER
if self._first_write:
self._init_csv_writer()
self._csv_writer.writerow(header)
self._first_write = False
self._csv_writer.writerow(data)
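
The writer above defers opening the file until the first row arrives and writes the header exactly once. A minimal standalone sketch of that pattern follows. `LazyCsvWriter` is a hypothetical class, and the `flush()` after each row is an assumption added here to illustrate the "written in real time" goal from the PR summary; the PR's writer does not call it.

```python
import csv
from pathlib import Path
from typing import Optional, Sequence, TextIO


class LazyCsvWriter:
    """Opens its csv file on the first write and emits the header once."""

    def __init__(self, path: Path, header: Sequence[str]):
        self._path = path
        self._header = list(header)
        self._file: Optional[TextIO] = None
        self._writer = None

    def write(self, row: Sequence) -> None:
        if self._file is None:
            # First write: create the directory, open the file, write the header.
            self._path.parent.mkdir(parents=True, exist_ok=True)
            self._file = open(self._path, "w", encoding="utf8", newline="")
            self._writer = csv.writer(self._file)
            self._writer.writerow(self._header)
        self._writer.writerow(row)
        # Assumption: flush per row so data survives a later crash.
        self._file.flush()

    def close(self) -> None:
        if self._file is not None:
            self._file.close()
            self._file = None
```

Because nothing is opened in the constructor, a session that never writes a row leaves no empty csv file behind.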

@@ -1,11 +1,20 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray[RLlib])
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which Red Agent algo/class will be used:
# "PPO" (Proximal Policy Optimization)
# "A2C" (Advantage Actor Critic)
# "HARDCODED" (Custom Agent)
# "RANDOM" (Random Action)
agent_identifier: PPO
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +27,7 @@ num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
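
The `session_type` value changes from `TRAINING` to `TRAIN` throughout these configs, so older config files will need updating. The sketch below shows how an enum lookup by name would reject the old value and how `learn()` and `evaluate()` gate on the session type. `SessionType` here is a stand-in for the enum in `primaite.common.enums`, the `TRAIN_EVAL` member is an assumption (only `TRAIN` and `EVAL` appear in this diff), and the helper functions are hypothetical.

```python
from enum import Enum


class SessionType(Enum):  # stand-in for primaite.common.enums.SessionType
    TRAIN = 1
    EVAL = 2
    TRAIN_EVAL = 3  # assumed member, not shown in this diff


def parse_session_type(value: str) -> SessionType:
    """Look up a session type by name, with a readable error for bad values."""
    try:
        return SessionType[value]
    except KeyError:
        valid = ", ".join(member.name for member in SessionType)
        raise ValueError(
            f"Unknown session_type {value!r}; expected one of: {valid}"
        ) from None


def should_learn(session_type: SessionType) -> bool:
    # Mirrors PrimaiteSession.learn: skipped only for EVAL sessions.
    return session_type is not SessionType.EVAL


def should_evaluate(session_type: SessionType) -> bool:
    # Mirrors PrimaiteSession.evaluate: skipped only for TRAIN sessions.
    return session_type is not SessionType.TRAIN
```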

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: A2C
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -28,7 +39,7 @@ observation_space:
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -24,7 +35,7 @@ observation_space:
time_delay: 1
# Filename of the scenario / laydown
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -25,7 +36,7 @@ observation_space:
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,10 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +30,7 @@ num_steps: 15
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -0,0 +1,112 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 2
# Number of time_steps per episode
num_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system
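
Note on the `agent_framework` / `agent_identifier` pairing: the comments in this config imply that A2C and PPO go with SB3 or RLLIB, while the other identifiers appear in these test configs with CUSTOM. A minimal sketch of how that pairing could be checked once the YAML has been loaded into a dict is given below. The enum classes and `check_agent_selection` are hypothetical names for illustration (the diff only confirms that an `AgentIdentifier` enum exists in `primaite.common.enums`), and treating CUSTOM as the only valid framework for the non-learning identifiers is an assumption drawn from the configs above, not a documented rule.

```python
from enum import Enum
from typing import Tuple


class AgentFramework(Enum):
    """Hypothetical mirror of the agent_framework options listed in the config."""

    SB3 = "SB3"
    RLLIB = "RLLIB"
    CUSTOM = "CUSTOM"


class AgentIdentifier(Enum):
    """Hypothetical mirror of the agent_identifier options listed in the config."""

    A2C = "A2C"
    PPO = "PPO"
    HARDCODED = "HARDCODED"
    DO_NOTHING = "DO_NOTHING"
    RANDOM = "RANDOM"
    DUMMY = "DUMMY"


# Identifiers that need a learning framework (SB3 or RLLIB).
LEARNING_IDENTIFIERS = {AgentIdentifier.A2C, AgentIdentifier.PPO}


def check_agent_selection(config: dict) -> Tuple[AgentFramework, AgentIdentifier]:
    """Return the parsed pair, or raise ValueError if the pairing looks wrong."""
    # Enum lookup by name raises KeyError for values not listed in the config comments.
    framework = AgentFramework[config["agent_framework"]]
    identifier = AgentIdentifier[config["agent_identifier"]]

    is_learning = identifier in LEARNING_IDENTIFIERS
    is_custom = framework is AgentFramework.CUSTOM
    # Assumption: learning identifiers never pair with CUSTOM, and vice versa.
    if is_learning == is_custom:
        raise ValueError(
            f"{identifier.name} is not expected with agent_framework {framework.name}"
        )
    return framework, identifier
```

With the values from this file, `check_agent_selection({"agent_framework": "CUSTOM", "agent_identifier": "DUMMY"})` passes, whereas pairing PPO with CUSTOM raises `ValueError`.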

@@ -1,43 +1,150 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import datetime
import shutil
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Union
from typing import Dict, Union
from unittest.mock import patch
import pytest
from primaite import getLogger
from primaite.common.enums import AgentIdentifier
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
def _get_temp_session_path(session_timestamp: datetime) -> Path:
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.
Uses context manager for deletion of files upon exit.
"""
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
def learn_av_reward_per_episode(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
@property
def env(self) -> Primaite:
"""Direct access to the env for ease of testing."""
return self._agent_session._env # noqa
def __enter__(self):
return self
def __exit__(self, type, value, tb):
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@pytest.fixture
def temp_primaite_session(request):
"""
Provides a temporary PrimaiteSession instance.
It's temporary as it uses a temporary directory as the session path.
To use this fixture you need to:
- parametrize your test function with:
- "temp_primaite_session"
- [[path to training config, path to lay down config]]
- Include the temp_primaite_session fixture as a param in your test
function.
- Use the temp_primaite_session as a context manager, assigning it the
name 'session'.
.. code:: python
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
@pytest.mark.parametrize(
"temp_primaite_session",
[
[main_training_config_path(), dos_very_basic_config_path()]
],
indirect=True
)
def test_primaite_session(temp_primaite_session):
with temp_primaite_session as session:
# Learning outputs are saved in session.learning_path
session.learn()
# Evaluation outputs are saved in session.evaluation_path
session.evaluate()
# To ensure that all files are written, you must call .close()
session.close()
# If you need to inspect any session outputs, it must be done
# inside the context manager
# Now that we've exited the context manager, the
# session.session_path directory and its contents are deleted
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)
@pytest.fixture
def temp_session_path() -> Path:
"""
Get a temp directory session path the test session will output to.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def _get_primaite_env_from_config(
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
temp_session_path,
):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp)
session_path = temp_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
transaction_list=[],
session_path=session_path,
timestamp_str=timestamp_str,
)
@@ -46,7 +153,7 @@ def _get_primaite_env_from_config(
# TODO: This needs to be refactored to happen outside. Should be part of
# a main Session class.
if env.training_config.agent_identifier == "GENERIC":
if env.training_config.agent_identifier is AgentIdentifier.RANDOM:
run_generic(env, config_values)
return env
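
Note on the cleanup pattern used by `TempPrimaiteSession`: the class deletes its output directory in `__exit__`, so anything a test wants to inspect has to be read inside the `with` block. The sketch below shows the same idea using only the standard library. `TempOutputDir` is a hypothetical stand-in, not the class from this PR, and it uses `tempfile.mkdtemp` instead of the timestamped layout so that it can run anywhere.

```python
import shutil
import tempfile
from pathlib import Path


class TempOutputDir:
    """Hypothetical stand-in: owns a temp directory and removes it on exit."""

    def __init__(self) -> None:
        # A fresh, uniquely named directory under the system temp location.
        self.path = Path(tempfile.mkdtemp(prefix="primaite_sketch_"))

    def __enter__(self) -> "TempOutputDir":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        # Remove the directory and its contents even if the body raised.
        shutil.rmtree(self.path, ignore_errors=True)
        # Returning False lets any exception from the body propagate.
        return False


with TempOutputDir() as tmp:
    marker = tmp.path / "session_metadata.json"
    marker.write_text("{}")
    # Outputs can only be inspected while still inside the context manager.
    existed_inside = marker.exists()

# After the block, the directory and everything in it are gone.
exists_after = tmp.path.exists()
```

Using `ignore_errors=True` is a design choice for the sketch only: it keeps a cleanup failure from masking the original test failure, at the cost of possibly leaving files behind.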

@@ -1,8 +0,0 @@
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
def test_primaite_main_e2e():
"""Tests the primaite.main.run function end-to-end."""
run(main_training_config_path(), data_manipulation_config_path())

@@ -0,0 +1,22 @@
import tempfile
from datetime import datetime
from pathlib import Path
from primaite import getLogger
_LOGGER = getLogger(__name__)
def get_temp_session_path(session_timestamp: datetime) -> Path:
"""
Get a temp directory session path the test session will output to.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created temp session directory: {session_path}")
return session_path
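
Note on the directory layout: `get_temp_session_path` nests each session under a date directory, which is presumably why `TempPrimaiteSession.__exit__` also removes `session_path.parent`. The sketch below reproduces only the path arithmetic, without creating anything on disk. `temp_session_layout` and the `root` parameter are hypothetical names introduced for illustration.

```python
from datetime import datetime
from pathlib import Path


def temp_session_layout(session_timestamp: datetime, root: Path) -> Path:
    """Build <root>/primaite/<date>/<date_time> without touching the file system."""
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    return root / "primaite" / date_dir / session_dir


# A fixed timestamp keeps the example deterministic.
example = temp_session_layout(datetime(2023, 6, 30, 12, 34, 56), Path("tmp_root"))
```

Because the directory name has one-second resolution, two sessions started within the same second would map to the same path; the `mkdir(exist_ok=True, parents=True)` call in the mock tolerates that rather than failing.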

@@ -95,8 +95,6 @@ def test_rule_hash():
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash(
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
assert hash_value_local == hash_value_remote

@@ -2,84 +2,79 @@
import numpy as np
import pytest
from primaite.environment.observations import (
NodeLinkTable,
NodeStatuses,
ObservationsHandler,
)
from primaite.environment.primaite_env import Primaite
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
@pytest.fixture
def env(request):
"""Build Primaite environment for integration tests of observation space."""
marker = request.node.get_closest_marker("env_config_paths")
training_config_path = marker.args[0]["training_config_path"]
lay_down_config_path = marker.args[0]["lay_down_config_path"]
env = _get_primaite_env_from_config(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
)
yield env
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
def test_default_obs_space(env: Primaite):
def test_default_obs_space(temp_primaite_session):
"""Create environment with no obs space defined in config and check that the default obs space was created."""
env.update_environent_obs()
with temp_primaite_session as session:
session.env.update_environent_obs()
components = env.obs_handler.registered_obs_components
components = session.env.obs_handler.registered_obs_components
assert len(components) == 1
assert isinstance(components[0], NodeLinkTable)
assert len(components) == 1
assert isinstance(components[0], NodeLinkTable)
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
def test_registering_components(env: Primaite):
def test_registering_components(temp_primaite_session):
"""Test registering and deregistering a component."""
handler = ObservationsHandler()
component = NodeStatuses(env)
handler.register(component)
assert component in handler.registered_obs_components
handler.deregister(component)
assert component not in handler.registered_obs_components
with temp_primaite_session as session:
env = session.env
handler = ObservationsHandler()
component = NodeStatuses(env)
handler.register(component)
assert component in handler.registered_obs_components
handler.deregister(component)
assert component not in handler.registered_obs_components
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_LINK_TABLE.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
class TestNodeLinkTable:
"""Test the NodeLinkTable observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with box observation space."""
env.update_environent_obs()
with temp_primaite_session as session:
env = session.env
env.update_environent_obs()
# we have three nodes and two links, with two services
# therefore the box observation space will have:
# * 5 rows (3 nodes + 2 links)
# * 6 columns (four fixed and two for the services)
assert env.env_obs.shape == (5, 6)
# we have three nodes and two links, with two services
# therefore the box observation space will have:
# * 5 rows (3 nodes + 2 links)
# * 6 columns (four fixed and two for the services)
assert env.env_obs.shape == (5, 6)
def test_value(self, env: Primaite):
def test_value(self, temp_primaite_session):
"""Test that the observation is generated correctly.
The laydown has:
@@ -125,36 +120,43 @@ class TestNodeLinkTable:
* 999 (999 traffic service1)
* 0 (no traffic for service2)
"""
# act = np.asarray([0,])
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
with temp_primaite_session as session:
env = session.env
# act = np.asarray([0,])
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs,
[
[1, 1, 3, 1, 1, 1],
[2, 1, 1, 1, 1, 4],
[3, 1, 1, 1, 0, 0],
[4, 0, 0, 0, 999, 0],
[5, 0, 0, 0, 999, 0],
],
)
assert np.array_equal(
obs,
[
[1, 1, 3, 1, 1, 1],
[2, 1, 1, 1, 1, 4],
[3, 1, 1, 1, 0, 0],
[4, 0, 0, 0, 999, 0],
[5, 0, 0, 0, 999, 0],
],
)
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_STATUSES.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
class TestNodeStatuses:
"""Test the NodeStatuses observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with NodeStatuses as the only component."""
assert env.env_obs.shape == (15,)
with temp_primaite_session as session:
env = session.env
assert env.env_obs.shape == (15,)
def test_values(self, env: Primaite):
def test_values(self, temp_primaite_session):
"""Test that the hardware and software states are encoded correctly.
The laydown has:
@@ -181,28 +183,36 @@ class TestNodeStatuses:
* service 1 = n/a (0)
* service 2 = n/a (0)
"""
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
with temp_primaite_session as session:
env = session.env
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
print(obs)
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
class TestLinkTrafficLevels:
"""Test the LinkTrafficLevels observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with MultiDiscrete observation space."""
env.update_environent_obs()
with temp_primaite_session as session:
env = session.env
env.update_environent_obs()
# we have two links and two services, so the shape should be 2 * 2
assert env.env_obs.shape == (2 * 2,)
# we have two links and two services, so the shape should be 2 * 2
assert env.env_obs.shape == (2 * 2,)
def test_values(self, env: Primaite):
def test_values(self, temp_primaite_session):
"""Test that traffic values are encoded correctly.
The laydown has:
@@ -212,12 +222,14 @@ class TestLinkTrafficLevels:
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
* link bandwidth of 1000, therefore the utilisation is 99.9%
"""
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
with temp_primaite_session as session:
env = session.env
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
# the observation space has combine_service_traffic set to False, so the space has this format:
# [link1_service1, link1_service2, link2_service1, link2_service2]
# we send 999 bits of data via link1 and link2 on service 1.
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])
# the observation space has combine_service_traffic set to False, so the space has this format:
# [link1_service1, link1_service2, link2_service1, link2_service2]
# we send 999 bits of data via link1 and link2 on service 1.
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])
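
Note on the expected observation shapes: the comments in these tests derive the shapes from the laydown (three nodes, two links, two services). The arithmetic can be sketched as below. Both helper functions are hypothetical and exist only to restate the comments; they are not part of `primaite.environment.observations`.

```python
from typing import Tuple


def node_link_table_shape(
    num_nodes: int, num_links: int, num_services: int, fixed_columns: int = 4
) -> Tuple[int, int]:
    """One row per node or link; four fixed columns plus one per service."""
    return (num_nodes + num_links, fixed_columns + num_services)


def link_traffic_levels_shape(num_links: int, num_services: int) -> Tuple[int]:
    """One entry per (link, service) pair when service traffic is not combined."""
    return (num_links * num_services,)


# Values taken from the test comments: 3 nodes, 2 links, 2 services.
table_shape = node_link_table_shape(num_nodes=3, num_links=2, num_services=2)
traffic_shape = link_traffic_levels_shape(num_links=2, num_services=2)
```

These match the `(5, 6)` and `(2 * 2,)` assertions in `TestNodeLinkTable` and `TestLinkTrafficLevels`.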

@@ -0,0 +1,55 @@
import os
import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[[main_training_config_path(), dos_very_basic_config_path()]],
indirect=True,
)
def test_primaite_session(temp_primaite_session):
"""Tests the PrimaiteSession class and its outputs."""
with temp_primaite_session as session:
session_path = session.session_path
assert session_path.exists()
session.learn()
# Learning outputs are saved in session.learning_path
session.evaluate()
# Evaluation outputs are saved in session.evaluation_path
# If you need to inspect any session outputs, it must be done inside
# the context manager
# Check that the metadata json file exists
assert (session_path / "session_metadata.json").exists()
# Check that the network png file exists
assert (session_path / f"network_{session.timestamp_str}.png").exists()
# Check that both the transactions and av reward csv files exist
for file in session.learning_path.iterdir():
if file.suffix == ".csv":
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
# Check that both the transactions and av reward csv files exist
for file in session.evaluation_path.iterdir():
if file.suffix == ".csv":
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
_LOGGER.debug("Inspecting files in temp session path...")
for dir_path, dir_names, file_names in os.walk(session_path):
for file in file_names:
path = os.path.join(dir_path, file)
file_str = path.split(str(session_path))[-1]
_LOGGER.debug(f" {file_str}")
# Now that we've exited the context manager, the session.session_path
# directory and its contents are deleted
assert not session_path.exists()
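
Note on the CSV checks: the two `iterdir()` loops assert that every CSV found has one of the expected names, but they would also pass if a directory contained no CSV files at all. If the intent is that both files must exist, a stricter check could look like the sketch below. `missing_csv_outputs` is a hypothetical helper, and the file names written in the example are illustrative only.

```python
import tempfile
from pathlib import Path
from typing import List, Sequence

REQUIRED_CSV_PREFIXES = ("all_transactions", "average_reward_per_episode")


def missing_csv_outputs(
    directory: Path, required: Sequence[str] = REQUIRED_CSV_PREFIXES
) -> List[str]:
    """Return the required name fragments that no CSV file in `directory` contains."""
    csv_names = [p.name for p in directory.iterdir() if p.suffix == ".csv"]
    return [
        prefix
        for prefix in required
        if not any(prefix in name for name in csv_names)
    ]


with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    # Only the transactions file exists at first.
    (out_dir / "all_transactions_2023-06-30_12-34-56.csv").write_text("")
    only_transactions = missing_csv_outputs(out_dir)
    # After the average reward file is written, nothing is missing.
    (out_dir / "average_reward_per_episode_2023-06-30_12-34-56.csv").write_text("")
    both_present = missing_csv_outputs(out_dir)
```

In a test this would read as `assert not missing_csv_outputs(session.learning_path)`, which fails with the list of missing files in the assertion message.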

@@ -1,68 +1,30 @@
from datetime import datetime
import pytest
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.environment.primaite_env import Primaite
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_temp_session_path
def run_generic(env, config_values):
"""Run against a generic agent."""
# Reset the environment at the start of the episode
env.reset()
for episode in range(0, config_values.num_episodes):
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
# action = env.action_space.sample()
action = 0
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Reset the environment at the end of the episode
env.reset()
env.close()
def test_random_red_agent_behaviour():
"""
Test that hardware state is penalised at each step.
When the initial state is OFF compared to reference state which is ON.
"""
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "test_random_red_main_config.yaml",
data_manipulation_config_path(),
]
],
indirect=True,
)
def test_random_red_agent_behaviour(temp_primaite_session):
"""Test that red agent POL is randomised each episode."""
list_of_node_instructions = []
# RUN TWICE so we can make sure that red agent is randomised
for i in range(2):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp)
with temp_primaite_session as session:
session.evaluate()
list_of_node_instructions.append(session.env.red_node_pol)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=data_manipulation_config_path(),
transaction_list=[],
session_path=session_path,
timestamp_str=timestamp_str,
)
# set red_agent_
env.training_config.random_red_agent = True
training_config = env.training_config
training_config.num_steps = env.episode_steps
run_generic(env, training_config)
# add red pol instructions to list
list_of_node_instructions.append(env.red_node_pol)
session.evaluate()
list_of_node_instructions.append(session.env.red_node_pol)
# compare instructions to make sure that red instructions are truly random
for index, instruction in enumerate(list_of_node_instructions):
@@ -73,5 +35,4 @@ def test_random_red_agent_behaviour():
print(f"{key} end step: {instruction.get_end_step()}")
print(f"{key} target node id: {instruction.get_target_node_id()}")
print("")
assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])
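
Note on the final assertion: whether it proves the red agent was randomised depends on how the instruction objects compare. Assuming `red_node_pol` is a dictionary of instruction objects (as the `key` in the print statements suggests), two dictionaries holding distinct objects compare unequal whenever the object class does not define `__eq__`, even if every field matches. This diff does not show whether `NodeStateInstructionRed` defines `__eq__`. The sketch below uses two hypothetical stand-in classes to show the difference.

```python
from dataclasses import dataclass


class PlainInstruction:
    """Hypothetical stand-in with no __eq__: compared by identity."""

    def __init__(self, start_step: int, end_step: int, target_node_id: str) -> None:
        self.start_step = start_step
        self.end_step = end_step
        self.target_node_id = target_node_id


@dataclass(frozen=True)
class ValueInstruction:
    """Hypothetical stand-in compared field by field (dataclass-generated __eq__)."""

    start_step: int
    end_step: int
    target_node_id: str


# Identical field values in both dictionaries.
plain_a = {"1": PlainInstruction(1, 5, "3")}
plain_b = {"1": PlainInstruction(1, 5, "3")}
value_a = {"1": ValueInstruction(1, 5, "3")}
value_b = {"1": ValueInstruction(1, 5, "3")}

# Identity comparison reports a difference although nothing was randomised.
plain_differs = plain_a != plain_b
# Value comparison reports a difference only when a field actually differs.
value_differs = value_a != value_b
```

If the instruction class turns out to be identity-compared, comparing tuples of start step, end step and target node ID would make the assertion meaningful. Writing `a != b` instead of `a.__ne__(b)` is also the more idiomatic form.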

@@ -1,13 +1,7 @@
"""Used to test Active Node functions."""
import pytest
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -57,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
file_system_state="GOOD",
config_values=1,
)
service_attributes = Service(
name="node", port="80", software_state=SoftwareState.COMPROMISED
)
service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
service_node.add_service(service_attributes)
for x in range(5):

@@ -1,21 +1,26 @@
import pytest
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def test_rewards_are_being_penalised_at_each_step_function():
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
]
],
indirect=True,
)
def test_rewards_are_being_penalised_at_each_step_function(
temp_primaite_session,
):
"""
Test that hardware state is penalised at each step.
When the initial state is OFF compared to reference state which is ON.
"""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_lay_down_config.yaml",
)
"""
The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps:
On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute.
For example, turning the nodes' file system state to CORRUPT from its original state GOOD.
@@ -38,4 +43,8 @@ def test_rewards_are_being_penalised_at_each_step_function():
For the 4 steps where this occurs the average reward is:
Average Reward: -8 (-120 / 15)
"""
assert env.average_reward == -8.0
with temp_primaite_session as session:
session.evaluate()
session.close()
ev_rewards = session.eval_av_reward_per_episode_csv()
assert ev_rewards[1] == -8.0
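
Note on the new assertion: the test now reads the average reward back from the CSV written by the session instead of from `env.average_reward`, and looks it up under episode `1`. A minimal sketch of that round trip is shown below. `read_av_rewards` is a hypothetical stand-in for `av_rewards_dict`, and the `Episode` and `Average Reward` column names are assumptions, since the CSV layout is not visible in this diff.

```python
import csv
import io
from typing import Dict


def read_av_rewards(csv_text: str) -> Dict[int, float]:
    """Map episode number to average reward (column names are assumed)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return {int(row["Episode"]): float(row["Average Reward"]) for row in reader}


# Figures from the test docstring: a total reward of -120 over 15 steps.
total_reward = -120
num_steps = 15
average = total_reward / num_steps

# Build an in-memory CSV with a single episode, keyed from 1 as in the test.
csv_text = f"Episode,Average Reward\n1,{average}\n"
rewards = read_av_rewards(csv_text)
```

Since -120 / 15 is exactly representable as a float, the strict `== -8.0` comparison in the test is safe here; averages that are not exact would be better compared with `pytest.approx`.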

@@ -1,9 +1,10 @@
import time
import pytest
from primaite.common.enums import HardwareState
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def run_generic_set_actions(env: Primaite):
@@ -17,7 +18,6 @@ def run_generic_set_actions(env: Primaite):
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1]
# Creates an ACL rule
@@ -44,59 +44,71 @@ def run_generic_set_actions(env: Primaite):
# env.close()
def test_single_action_space_is_valid():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
]
],
indirect=True,
)
def test_single_action_space_is_valid(temp_primaite_session):
"""Test single action space is valid."""
with temp_primaite_session as session:
env = session.env
run_generic_set_actions(env)
# Retrieve the action space dictionary values from environment
env_action_space_dict = env.action_dict.values()
# Flags to check the conditions of the action space
contains_acl_actions = False
contains_node_actions = False
both_action_spaces = False
# Loop through each element of the list (which is every value from the dictionary)
for dict_item in env_action_space_dict:
# Node action detected
if len(dict_item) == 4:
contains_node_actions = True
# ACL action detected
elif len(dict_item) == 6:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
both_action_spaces = True
# Check condition should be True
assert both_action_spaces
run_generic_set_actions(env)
# Retrieve the action space dictionary values from environment
env_action_space_dict = env.action_dict.values()
# Flags to check the conditions of the action space
contains_acl_actions = False
contains_node_actions = False
both_action_spaces = False
# Loop through each element of the list (which is every value from the dictionary)
for dict_item in env_action_space_dict:
# Node action detected
if len(dict_item) == 4:
contains_node_actions = True
# ACL action detected
elif len(dict_item) == 6:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
both_action_spaces = True
# Check condition should be True
assert both_action_spaces
def test_agent_is_executing_actions_from_both_spaces():
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
]
],
indirect=True,
)
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)
# Retrieve hardware state of computer_1 node in laydown config
# Agent turned this off in Step 5
computer_node_hardware_state = env.nodes["1"].hardware_state
# Retrieve the Access Control List object stored by the environment at the end of the episode
access_control_list = env.acl
# Use the Access Control List object acl object attribute to get dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1
with temp_primaite_session as session:
env = session.env
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)
# Retrieve hardware state of computer_1 node in laydown config
# Agent turned this off in Step 5
computer_node_hardware_state = env.nodes["1"].hardware_state
# Retrieve the Access Control List object stored by the environment at the end of the episode
access_control_list = env.acl
# Use the Access Control List object acl object attribute to get dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1
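
Both tests above now obtain their environment through `with temp_primaite_session as session:` rather than `_get_primaite_env_from_config`. A minimal sketch of that context-manager pattern follows, assuming only what the PR summary states: the session writes to a temp directory and cleans up after itself. `TempSessionSketch`, its attributes, and the placeholder `env` are hypothetical stand-ins for illustration, not the actual `TempPrimaiteSession` implementation.

```python
import shutil
import tempfile
from pathlib import Path


class TempSessionSketch:
    """Hypothetical stand-in for TempPrimaiteSession; not the PrimAITE class."""

    def __init__(self, training_config_path, lay_down_config_path):
        self.training_config_path = Path(training_config_path)
        self.lay_down_config_path = Path(lay_down_config_path)
        self.session_path = None
        self.env = None

    def __enter__(self):
        # Outputs go to a throwaway directory, not the main session directory.
        self.session_path = Path(tempfile.mkdtemp(prefix="primaite_test_"))
        # Placeholder for the environment a real session would build (assumption).
        self.env = {"action_dict": {}}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Remove the temp directory whether or not the test body raised.
        shutil.rmtree(self.session_path, ignore_errors=True)
        return False  # never swallow assertion failures from the test body


# Usage mirrors the shape of the tests: the env is only used inside the block.
with TempSessionSketch("main_config.yaml", "lay_down_config.yaml") as session:
    session_dir = session.session_path
    existed_inside = session_dir.exists()

removed_after = not session_dir.exists()
```

Returning `False` from `__exit__` matters here: a failing `assert` inside the `with` block still propagates to pytest, while the temp directory is removed either way.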

@@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT
def test_legacy_lay_down_config_yaml_conversion():
"""Tests the conversion of legacy lay down config files."""
legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
with open(legacy_path, "r") as file:
legacy_dict = yaml.safe_load(file)
@@ -24,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion():
def test_create_config_values_main_from_file():
"""Tests creating an instance of TrainingConfig from file."""
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
training_config.load(new_path)
def test_create_config_values_main_from_legacy_file():
"""Tests creating an instance of TrainingConfig from legacy file."""
new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
training_config.load(new_path, legacy_file=True)