Merged PR 101: Integrate ADSP RLlib and use PrimaiteSession for running between agent frameworks
## Summary

* Brought over the RLlib, hardcoded, and simple agents from ADSP 1.1.0. This opened a can of worms... ADSP got their stuff working in notebooks (***_stares at data scientists!_** 😂) but hadn't integrated it into the PrimAITE package or made the other PrimAITE functionality work with it.
* RLlib agents have been fully integrated with the wider PrimAITE package. This was done by:
  * Creating the `AgentSessionABC` and `HardCodedAgentSessionABC` classes.
  * Making the `SB3Agent` and `RLlibAgent` classes inherit from `AgentSessionABC`.
  * Integrating the ADSP hardcoded agents into subclasses of `HardCodedAgentSessionABC`.
  * Integrating the random and dummy agents into subclasses of `HardCodedAgentSessionABC`.
* A set of session output directories is created and managed by the agent session, so session outputs are stored in a common format regardless of the agent type.
* The main config was refactored to add:
  * **agent_framework** - identifies whether SB3, RLlib, or Custom.
  * **agent_identifier** - identifies whether PPO, A2C, hardcoded, random, or dummy.
  * **deep_learning_framework** - identifies which framework RLlib should use.
* Transactions have been overhauled to simplify the process. They're now written in real time, so they're not lost if the agent crashes.
* Tests were completely overhauled to use `PrimaiteSession`, or at least a test subclass, `TempPrimaiteSession`. It's "temp" because it uses a temp directory rather than the main PrimAITE session directory, and it cleans up after itself.
* All the cruft was removed from `main.py` so that it just runs `PrimaiteSession`. Now this is where I went off on a tangent...
* A CLI was added to make my life and everyone else's easier.
* A PrimAITE app config was added to hold things like logging format, levels, etc.
* A `primaite.data_viz.session_plots` module was added so that the average reward per episode is plotted and saved for each session (this helped while we were testing and bug fixing).

## Test process

* All tests use `TempPrimaiteSession`, which uses `PrimaiteSession`.
* I still need to write tests that run the RLlib, hardcoded, and random/dummy agents. I'll do that now while this is being reviewed.

## Still to do

* Update docs. I'm getting this PR up now so we can make use of the features. I'll get the docs updated today, either on this branch or another branch (depending on how long this review takes).

## Checklist

- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #917, #1563
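The session class hierarchy described in the summary could be sketched roughly as below. Only `AgentSessionABC` and `HardCodedAgentSessionABC` are names from this PR; the constructor, the `learn` signature, and `RandomAgentSession` are illustrative assumptions.

```python
from abc import ABC, abstractmethod
from pathlib import Path
import tempfile


class AgentSessionABC(ABC):
    """Owns the session output directories so every agent type writes a common layout."""

    def __init__(self, output_root: Path):
        # Consistent output location regardless of agent framework
        self.session_dir = output_root / "sessions"
        self.session_dir.mkdir(parents=True, exist_ok=True)

    @abstractmethod
    def learn(self, num_episodes: int, num_steps: int) -> None:
        """Train (or run) the agent for the given number of episodes and steps."""


class HardCodedAgentSessionABC(AgentSessionABC):
    """Base for deterministic agents, which have no training phase."""

    def learn(self, num_episodes: int, num_steps: int) -> None:
        pass  # hardcoded/random/dummy agents just act; nothing to train


class RandomAgentSession(HardCodedAgentSessionABC):
    """Stochastic baseline agent session (hypothetical subclass)."""


session = RandomAgentSession(Path(tempfile.mkdtemp()))
print(session.session_dir.exists())  # True
```

The point of the shared base is that framework-specific subclasses (`SB3Agent`, `RLlibAgent`) only supply the training loop, while output management stays in one place.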
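As a rough illustration of the test approach: only `TempPrimaiteSession` is a name from this PR; modelling it as a context manager, and the attribute names used here, are assumptions.

```python
import tempfile
from pathlib import Path


class TempPrimaiteSession:
    """Test double: writes all session outputs to a temp directory and cleans up."""

    def __enter__(self) -> "TempPrimaiteSession":
        self._tmp = tempfile.TemporaryDirectory()
        self.session_path = Path(self._tmp.name)
        return self

    def __exit__(self, *exc) -> None:
        # Remove outputs instead of polluting the main PrimAITE session directory
        self._tmp.cleanup()


with TempPrimaiteSession() as session:
    path = session.session_path
    assert path.exists()
assert not path.exists()  # directory removed on exit
```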
@@ -1,38 +1,12 @@
trigger:
- main

pool:
  vmImage: ubuntu-latest
strategy:
  matrix:
    Ubuntu2004Python38:
      python.version: '3.8'
      imageName: 'ubuntu-20.04'
    Ubuntu2004Python39:
      python.version: '3.9'
      imageName: 'ubuntu-20.04'
    Ubuntu2004Python310:
    Python310:
      python.version: '3.10'
      imageName: 'ubuntu-20.04'
    WindowsPython38:
      python.version: '3.8'
      imageName: 'windows-latest'
    WindowsPython39:
      python.version: '3.9'
      imageName: 'windows-latest'
    WindowsPython310:
      python.version: '3.10'
      imageName: 'windows-latest'
    MacPython38:
      python.version: '3.8'
      imageName: 'macOS-latest'
    MacPython39:
      python.version: '3.9'
      imageName: 'macOS-latest'
    MacPython310:
      python.version: '3.10'
      imageName: 'macOS-latest'

pool:
  vmImage: $(imageName)

steps:
- task: UsePythonVersion@0
@@ -6,18 +6,29 @@ trigger:
- bugfix/*
- release/*

pool:
  vmImage: ubuntu-latest
strategy:
  matrix:
    Python38:
    UbuntuPython38:
      python.version: '3.8'
    Python39:
      python.version: '3.9'
    Python310:
      imageName: 'ubuntu-latest'
    UbuntuPython310:
      python.version: '3.10'
    Python311:
      python.version: '3.11'
      imageName: 'ubuntu-latest'
    WindowsPython38:
      python.version: '3.8'
      imageName: 'windows-latest'
    WindowsPython310:
      python.version: '3.10'
      imageName: 'windows-latest'
    MacOSPython38:
      python.version: '3.8'
      imageName: 'macOS-latest'
    MacOSPython310:
      python.version: '3.10'
      imageName: 'macOS-latest'

pool:
  vmImage: $(imageName)

steps:
- task: UsePythonVersion@0
@@ -47,16 +58,17 @@ steps:
    PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
    python -m pip install $PRIMAITE_WHEEL[dev]
  displayName: 'Install PrimAITE'
  condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))

- script: |
    forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
  displayName: 'Install PrimAITE'
  condition: eq( variables['Agent.OS'], 'Windows_NT' )

- script: |
    primaite setup
  displayName: 'Perform PrimAITE Setup'

#- script: |
#    flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
#    flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
#  displayName: 'Lint with flake8'

- script: |
    pytest tests/
  displayName: 'Run unmarked tests'
  displayName: 'Run tests'
@@ -13,6 +13,9 @@ repos:
    rev: 23.1.0
    hooks:
      - id: black
        args: [ "--line-length=120" ]
        additional_dependencies:
          - jupyter
  - repo: http://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
@@ -22,4 +25,5 @@ repos:
    rev: 6.0.0
    hooks:
      - id: flake8
        additional_dependencies: [ flake8-docstrings ]
        additional_dependencies:
          - flake8-docstrings
@@ -20,13 +20,24 @@ The environment config file consists of the following attributes:

**Generic Config Values**

* **agent_identifier** [enum]

  This identifies the agent to use for the session. Select from one of the following:

* **agent_framework** [enum]

  This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:

  * NONE - Where a user-developed agent is to be used
  * SB3 - Stable Baselines3
  * RLLIB - Ray RLlib

* **agent_identifier** [enum]

  This identifies the agent to use for the session. Select from one of the following:

  * A2C - Advantage Actor Critic
  * PPO - Proximal Policy Optimization
  * HARDCODED - A custom-built deterministic agent
  * RANDOM - A stochastic random agent

  * GENERIC - Where a user-developed agent is to be used
  * STABLE_BASELINES3_PPO - Use an SB3 PPO agent
  * STABLE_BASELINES3_A2C - Use an SB3 A2C agent

* **random_red_agent** [bool]
@@ -34,38 +45,38 @@ The environment config file consists of the following attributes:

* **action_type** [enum]

  Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session

* **num_episodes** [int]

  This defines the number of episodes that the agent will train or be evaluated over.

* **num_steps** [int]

  Determines the number of steps to run in each episode of the session

* **time_delay** [int]

  The time delay (in milliseconds) to take between each step when running a GENERIC agent session

* **session_type** [text]

  Type of session to be run (TRAINING or EVALUATION)
  Type of session to be run (TRAINING, EVALUATION, or BOTH)

* **load_agent** [bool]

  Determine whether to load an agent from file

* **agent_load_file** [text]

  File path and file name of agent if you're loading one in

* **observation_space_high_value** [int]

  The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases

**Reward-Based Config Values**
@@ -73,95 +84,95 @@ Rewards are calculated based on the difference between the current state and ref

* **Generic [all_ok]** [int]

  The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)

* **Node Hardware State [off_should_be_on]** [int]

  The score to give when the node should be on, but is off

* **Node Hardware State [off_should_be_resetting]** [int]

  The score to give when the node should be resetting, but is off

* **Node Hardware State [on_should_be_off]** [int]

  The score to give when the node should be off, but is on

* **Node Hardware State [on_should_be_resetting]** [int]

  The score to give when the node should be resetting, but is on

* **Node Hardware State [resetting_should_be_on]** [int]

  The score to give when the node should be on, but is resetting

* **Node Hardware State [resetting_should_be_off]** [int]

  The score to give when the node should be off, but is resetting

* **Node Hardware State [resetting]** [int]

  The score to give when the node is resetting

* **Node Operating System or Service State [good_should_be_patching]** [int]

  The score to give when the state should be patching, but is good

* **Node Operating System or Service State [good_should_be_compromised]** [int]

  The score to give when the state should be compromised, but is good

* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]

  The score to give when the state should be overwhelmed, but is good

* **Node Operating System or Service State [patching_should_be_good]** [int]

  The score to give when the state should be good, but is patching

* **Node Operating System or Service State [patching_should_be_compromised]** [int]

  The score to give when the state should be compromised, but is patching

* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]

  The score to give when the state should be overwhelmed, but is patching

* **Node Operating System or Service State [patching]** [int]

  The score to give when the state is patching

* **Node Operating System or Service State [compromised_should_be_good]** [int]

  The score to give when the state should be good, but is compromised

* **Node Operating System or Service State [compromised_should_be_patching]** [int]

  The score to give when the state should be patching, but is compromised

* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]

  The score to give when the state should be overwhelmed, but is compromised

* **Node Operating System or Service State [compromised]** [int]

  The score to give when the state is compromised

* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]

  The score to give when the state should be good, but is overwhelmed

* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]

  The score to give when the state should be patching, but is overwhelmed

* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]

  The score to give when the state should be compromised, but is overwhelmed

* **Node Operating System or Service State [overwhelmed]** [int]

  The score to give when the state is overwhelmed

* **Node File System State [good_should_be_repairing]** [int]
@@ -265,37 +276,37 @@ Rewards are calculated based on the difference between the current state and ref

* **IER Status [red_ier_running]** [int]

  The score to give when a red agent IER is permitted to run

* **IER Status [green_ier_blocked]** [int]

  The score to give when a green agent IER is prevented from running

**Patching / Reset Durations**

* **os_patching_duration** [int]

  The number of steps to take when patching an Operating System

* **node_reset_duration** [int]

  The number of steps to take when resetting a node's hardware state

* **service_patching_duration** [int]

  The number of steps to take when patching a service

* **file_system_repairing_limit** [int]

  The number of steps to take when repairing the file system

* **file_system_restoring_limit** [int]

  The number of steps to take when restoring the file system

* **file_system_scanning_limit** [int]

  The number of steps to take when scanning the file system

The Lay Down Config
*******************
@@ -304,22 +315,22 @@ The lay down config file consists of the following attributes:

* **itemType: ACTIONS** [enum]

  Determines whether a NODE or ACL action space format is adopted for the session

* **itemType: OBSERVATION_SPACE** [dict]

  Allows the user to configure the observation space by combining one or more observation components. The list of available components is in :py:mod:`primaite.environment.observations`.

  The observation space config item should have a ``components`` key which is a list of components. Each component config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the component while it is being initialised.

  This example illustrates the correct format for the observation space config item:

  .. code-block:: yaml

    - itemType: OBSERVATION_SPACE
    - item_type: OBSERVATION_SPACE
      components:
        - name: LINK_TRAFFIC_LEVELS
          options:
@@ -332,15 +343,15 @@ The lay down config file consists of the following attributes:

* **item_type: PORTS** [int]

  Provides a list of ports modelled in this session

* **item_type: SERVICES** [freetext]

  Provides a list of services modelled in this session

* **item_type: NODE**

  Defines a node included in the system laydown being simulated. It should consist of the following attributes:

  * **id** [int]: Unique ID for this YAML item
  * **name** [freetext]: Human-readable name of the component
@@ -359,7 +370,7 @@ The lay down config file consists of the following attributes:

* **item_type: LINK**

  Defines a link included in the system laydown being simulated. It should consist of the following attributes:

  * **id** [int]: Unique ID for this YAML item
  * **name** [freetext]: Human-readable name of the component
@@ -369,7 +380,7 @@ The lay down config file consists of the following attributes:

* **item_type: GREEN_IER**

  Defines a green agent Information Exchange Requirement (IER). It should consist of:

  * **id** [int]: Unique ID for this YAML item
  * **start_step** [int]: The start step (in the episode) for this IER to begin
@@ -383,7 +394,7 @@ The lay down config file consists of the following attributes:

* **item_type: RED_IER**

  Defines a red agent Information Exchange Requirement (IER). It should consist of:

  * **id** [int]: Unique ID for this YAML item
  * **start_step** [int]: The start step (in the episode) for this IER to begin
@@ -47,6 +47,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| astunparse | 1.6.3 | BSD License | https://github.com/simonpercivall/astunparse |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall |
@@ -103,6 +105,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn |
@@ -111,9 +115,13 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| furo | 2023.3.27 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-auth-oauthlib | 1.0.0 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| grpcio | 1.51.3 | Apache Software License | https://grpc.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -121,6 +129,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| h5py | 3.9.0 | BSD License | https://www.h5py.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| idna | 3.4 | BSD License | https://github.com/kjd/idna |
@@ -141,6 +151,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 |
@@ -151,14 +163,14 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-events | 0.6.3 | BSD License | http://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_client | 8.2.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_core | 5.3.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org |
@@ -171,10 +183,14 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| keras | 2.12.0 | Apache Software License | https://keras.io/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| kiwisolver | 1.4.4 | BSD License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py |
@@ -183,19 +199,23 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mpmath | 1.3.0 | BSD License | http://mpmath.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic |
| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| nbclient | 0.8.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -217,6 +237,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging |
@@ -231,11 +253,17 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| plotly | 5.15.0 | MIT License | https://plotly.com/python/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 1.2.1 | MIT License | UNKNOWN |
| primaite | 2.0.0rc1 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -295,6 +323,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| rsa | 4.9 | Apache Software License | https://stuvel.eu/rsa |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| scikit-image | 0.20.0 | BSD License | https://scikit-image.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| scipy | 1.10.1 | BSD License | https://scipy.org/ |
@@ -335,14 +365,26 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard | 2.12.3 | Apache Software License | https://github.com/tensorflow/tensorboard |
| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-data-server | 0.7.0 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com |
@@ -377,6 +419,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket |
@@ -7,7 +7,7 @@ name = "primaite"
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
authors = [{name="QinetiQ Training and Simulation Ltd"}]
license = {text = "MIT License"}
requires-python = ">=3.8"
requires-python = ">=3.8, <3.11"
dynamic = ["version", "readme"]
classifiers = [
    "License :: MIT License",
@@ -20,19 +20,23 @@ classifiers = [
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3 :: Only",
]

dependencies = [
    "gym==0.21.0",
    "jupyterlab==3.6.1",
    "kaleido==0.2.1",
    "matplotlib==3.7.1",
    "networkx==3.1",
    "numpy==1.23.5",
    "platformdirs==3.5.1",
    "plotly==5.15.0",
    "polars==0.18.4",
    "PyYAML==6.0",
    "ray[rllib]==2.2.0",
    "stable-baselines3==1.6.2",
    "tensorflow==2.12.0",
    "typer[all]==0.9.0"
]

@@ -62,9 +66,15 @@ dev = [
    "sphinx-copybutton==0.5.2",
    "wheel==0.38.4"
]
tensorflow = [
    "tensorflow==2.12.0",
]

[project.scripts]
primaite = "primaite.cli:app"

[tool.isort]
profile = "black"
line_length = 120
force_sort_within_sections = "False"
order_by_type = "False"

[tool.black]
line-length = 120
@@ -1 +1 @@
2.0.0dev0
2.0.0rc1
@@ -2,10 +2,11 @@
import logging
import logging.config
import sys
from logging import Logger, StreamHandler
from bisect import bisect
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Final
from typing import Dict, Final

import pkg_resources
import yaml
@@ -18,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
def _get_primaite_config():
    config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
    if not config_path.exists():
        config_path = Path(
            pkg_resources.resource_filename(
                "primaite", "setup/_package_data/primaite_config.yaml"
            )
        )
        config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
    with open(config_path, "r") as file:
        primaite_config = yaml.safe_load(file)
    log_level_map = {
@@ -33,7 +30,7 @@ def _get_primaite_config():
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
    primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
    return primaite_config


@@ -68,6 +65,28 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.


# region Setup Logging
class _LevelFormatter(Formatter):
    """
    A custom level-specific formatter.

    Credit to: https://stackoverflow.com/a/68154386
    """

    def __init__(self, formats: Dict[int, str], **kwargs):
        super().__init__()

        if "fmt" in kwargs:
            raise ValueError("Format string must be passed to level-surrogate formatters, not this one")

        self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items())

    def format(self, record: LogRecord) -> str:
        """Overrides ``Formatter.format``."""
        idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
        level, formatter = self.formats[idx]
        return formatter.format(record)
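The `bisect` trick above keeps a sorted list of `(level, Formatter)` pairs and looks up the formatter for the record's level in logarithmic time. A minimal stand-alone sketch of the same technique; the class name and format strings here are illustrative, not PrimAITE's:

```python
import logging
from bisect import bisect
from logging import Formatter, LogRecord


class LevelFormatter(Formatter):
    """Pick a different format string depending on the record's level."""

    def __init__(self, formats):
        super().__init__()
        # Sorted (level, Formatter) pairs so bisect can find the right one.
        self.formats = sorted((lvl, Formatter(fmt)) for lvl, fmt in formats.items())

    def format(self, record: LogRecord) -> str:
        # bisect returns the insertion point for the record's level; hi= caps
        # it so levels at or above the highest key map to the last formatter.
        idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
        return self.formats[idx][1].format(record)


level_fmt = LevelFormatter({logging.DEBUG: "DEBUG | %(message)s", logging.ERROR: "ERROR!! %(message)s"})
rec = logging.LogRecord("demo", logging.ERROR, __file__, 1, "boom", None, None)
print(level_fmt.format(rec))  # ERROR!! boom
```

Because the lookup compares `(levelno,)` tuples against `(level, Formatter)` pairs, distinct integer keys are required; the real class also forbids a plain `fmt` kwarg for the same reason.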


def _log_dir() -> Path:
    if sys.platform == "win32":
        dir_path = _PLATFORM_DIRS.user_data_path / "logs"
@@ -76,6 +95,16 @@ def _log_dir() -> Path:
    return dir_path


_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
    {
        logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
        logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
        logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
        logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
        logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
    }
)

LOG_DIR: Final[Path] = _log_dir()
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""

@@ -85,18 +114,19 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""

_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()

_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
    filename=LOG_PATH,
    maxBytes=10485760,  # 10MB
    backupCount=9,  # Max 100MB of logs
    encoding="utf8",
)
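The rotation budget in that handler config is easy to sanity-check: `maxBytes` is 10 MiB and `backupCount=9` keeps nine rotated files alongside the active `primaite.log`, so at most 100 MiB of logs ever sit on disk:

```python
max_bytes = 10_485_760  # 10 MiB, the maxBytes value configured above
backup_count = 9        # nine rotated backups plus the one active log file

# Worst-case disk usage across the active file and all backups, in MiB.
total_mib = (backup_count + 1) * max_bytes / 1024**2
print(total_mib)  # 100.0
```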
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])

_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"]
_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)

_LOGGER = logging.getLogger(__name__)

@@ -104,7 +134,7 @@ _LOGGER.addHandler(_STREAM_HANDLER)
_LOGGER.addHandler(_FILE_HANDLER)


def getLogger(name: str) -> Logger:
def getLogger(name: str) -> Logger:  # noqa
    """
    Get a PrimAITE logger.

||||
@@ -25,18 +25,9 @@ class AccessControlList:
        True if match; False otherwise.
        """
        if (
            (
                _rule.get_source_ip() == _source_ip_address
                and _rule.get_dest_ip() == _dest_ip_address
            )
            or (
                _rule.get_source_ip() == "ANY"
                and _rule.get_dest_ip() == _dest_ip_address
            )
            or (
                _rule.get_source_ip() == _source_ip_address
                and _rule.get_dest_ip() == "ANY"
            )
            (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
            or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
            or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
            or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
        ):
            return True
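Spelled out, the four-clause condition above reduces to "each address matches literally, or the rule says ANY". A stand-alone sketch of that rule; the helper names here are illustrative, not part of PrimAITE:

```python
def field_matches(rule_value: str, packet_value: str) -> bool:
    """A rule field matches when it is the ANY wildcard or an exact match."""
    return rule_value == "ANY" or rule_value == packet_value


def address_match(rule_src: str, rule_dst: str, src: str, dst: str) -> bool:
    """Equivalent to the four-way OR: both source and destination must match."""
    return field_matches(rule_src, src) and field_matches(rule_dst, dst)


print(address_match("ANY", "10.0.0.2", "10.0.0.1", "10.0.0.2"))   # True
print(address_match("10.0.0.9", "ANY", "10.0.0.1", "10.0.0.2"))   # False
```

The same per-field pattern also covers the protocol and port checks further down, which is why the diff can add the `("ANY", "ANY")` clause without changing the shape of the condition.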
@@ -57,15 +48,9 @@ class AccessControlList:
        Indicates block if all conditions are satisfied.
        """
        for rule_key, rule_value in self.acl.items():
            if self.check_address_match(
                rule_value, _source_ip_address, _dest_ip_address
            ):
                if (
                    rule_value.get_protocol() == _protocol
                    or rule_value.get_protocol() == "ANY"
                ) and (
                    str(rule_value.get_port()) == str(_port)
                    or rule_value.get_port() == "ANY"
            if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
                if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
                    str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
                ):
                    # There's a matching rule. Get the permission
                    if rule_value.get_permission() == "DENY":

@@ -30,7 +30,13 @@ class ACLRule:
        Returns hash of core parameters.
        """
        return hash(
            (self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
            (
                self.permission,
                self.source_ip,
                self.dest_ip,
                self.protocol,
                self.port,
            )
        )

    def get_permission(self):
||||
src/primaite/agents/agent.py (new file, 383 lines)
@@ -0,0 +1,383 @@
from __future__ import annotations

import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Union
from uuid import uuid4

import yaml

import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite

_LOGGER = getLogger(__name__)


def get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
    ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = SESSIONS_DIR / date_dir / session_path
    session_path.mkdir(exist_ok=True, parents=True)

    return session_path
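The directory scheme can be exercised with just the standard library. This sketch mirrors the `<root>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>` layout without creating anything on disk; the `SESSIONS_DIR` value here is illustrative, standing in for the real `~/primaite/sessions` root:

```python
from datetime import datetime
from pathlib import Path

# Illustrative root; PrimAITE resolves this from primaite.SESSIONS_DIR.
SESSIONS_DIR = Path.home() / "primaite" / "sessions"


def session_path_for(ts: datetime) -> Path:
    """Build <root>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss> (no mkdir side effect)."""
    return SESSIONS_DIR / ts.strftime("%Y-%m-%d") / ts.strftime("%Y-%m-%d_%H-%M-%S")


p = session_path_for(datetime(2023, 6, 5, 14, 30, 0))
print(p.parent.name, p.name)  # 2023-06-05 2023-06-05_14-30-00
```

Grouping sessions under a per-day directory keeps the sessions root navigable even after many runs, since all runs from one day collapse into a single date folder.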


class AgentSessionABC(ABC):
    """
    An ABC that manages training and/or evaluation of agents in PrimAITE.

    This class cannot be directly instantiated and must be inherited from
    with all abstract methods implemented.
    """

    @abstractmethod
    def __init__(self, training_config_path, lay_down_config_path):
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Union[Path, str]] = training_config_path
        self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
        self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
        self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level

        self._env: Primaite
        self._agent = None
        self._can_learn: bool = False
        self._can_evaluate: bool = False
        self.is_eval = False

        self._uuid = str(uuid4())
        self.session_timestamp: datetime = datetime.now()
        "The session timestamp"
        self.session_path = get_session_path(self.session_timestamp)
        "The Session path"

    @property
    def timestamp_str(self) -> str:
        """The session timestamp as a string."""
        return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")

    @property
    def learning_path(self) -> Path:
        """The learning outputs path."""
        path = self.session_path / "learning"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def evaluation_path(self) -> Path:
        """The evaluation outputs path."""
        path = self.session_path / "evaluation"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def checkpoints_path(self) -> Path:
        """The Session checkpoints path."""
        path = self.learning_path / "checkpoints"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def uuid(self):
        """The Agent Session UUID."""
        return self._uuid

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_path`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items
        """
        metadata_dict = {
            "uuid": self.uuid,
            "start_datetime": self.session_timestamp.isoformat(),
            "end_datetime": None,
            "learning": {"total_episodes": None, "total_time_steps": None},
            "evaluation": {"total_episodes": None, "total_time_steps": None},
            "env": {
                "training_config": self._training_config.to_dict(json_serializable=True),
                "lay_down_config": self._lay_down_config,
            },
        }
        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished writing session metadata file")

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Updates the ``session_metadata.json`` in the ``session_path`` directory
        with the following key/value pairs:

        - end_datetime: The date & time the session ended in iso format.
        - total_episodes: The total number of training episodes completed.
        - total_time_steps: The total number of training time steps completed.
        """
        with open(self.session_path / "session_metadata.json", "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()

        if not self.is_eval:
            metadata_dict["learning"]["total_episodes"] = self._env.episode_count  # noqa
            metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count  # noqa
        else:
            metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count  # noqa
            metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count  # noqa

        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished updating session metadata file")
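The write-then-update flow above is what lets a crashed run still leave a readable metadata file: totals start as NULL and are only filled in on completion. It can be exercised in isolation with the standard library; the field values here are dummies:

```python
import json
import tempfile
from pathlib import Path

# Illustrative round-trip of the metadata pattern: write the file with NULL
# totals at session start, then re-read and fill them in when the run ends.
with tempfile.TemporaryDirectory() as tmp:
    md_path = Path(tmp) / "session_metadata.json"
    metadata = {
        "uuid": "demo-uuid",
        "end_datetime": None,
        "learning": {"total_episodes": None, "total_time_steps": None},
    }
    md_path.write_text(json.dumps(metadata))

    # ... the session runs; a crash here still leaves valid JSON on disk ...

    metadata = json.loads(md_path.read_text())
    metadata["learning"]["total_episodes"] = 10
    metadata["learning"]["total_time_steps"] = 2560
    md_path.write_text(json.dumps(metadata))

    result = json.loads(md_path.read_text())

print(result["learning"]["total_episodes"])  # 10
```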

    @abstractmethod
    def _setup(self):
        _LOGGER.info(
            f"Welcome to the Primary-level AI Training Environment (PrimAITE) (version: {primaite.__version__})"
        )
        _LOGGER.info(f"The output directory for this session is: {self.session_path}")
        self._write_session_metadata_file()
        self._can_learn = True
        self._can_evaluate = False

    @abstractmethod
    def _save_checkpoint(self):
        pass

    @abstractmethod
    def learn(
        self,
        **kwargs,
    ):
        """
        Train the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        if self._can_learn:
            _LOGGER.info("Finished learning")
            _LOGGER.debug("Writing transactions")
            self._update_session_metadata_file()
            self._can_evaluate = True
            self.is_eval = False
            self._plot_av_reward_per_episode(learning_session=True)

    @abstractmethod
    def evaluate(
        self,
        **kwargs,
    ):
        """
        Evaluate the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        self._env.set_as_eval()  # noqa
        self.is_eval = True
        self._plot_av_reward_per_episode(learning_session=False)
        _LOGGER.info("Finished evaluation")

    @abstractmethod
    def _get_latest_checkpoint(self):
        pass

    @classmethod
    @abstractmethod
    def load(cls, path: Union[str, Path]) -> AgentSessionABC:
        """Load an agent from file."""
        if not isinstance(path, Path):
            path = Path(path)

        if path.exists():
            # Unpack the session_metadata.json file
            md_file = path / "session_metadata.json"
            with open(md_file, "r") as file:
                md_dict = json.load(file)

            # Create a temp directory and dump the training and lay down
            # configs into it
            temp_dir = path / ".temp"
            temp_dir.mkdir(exist_ok=True)

            temp_tc = temp_dir / "tc.yaml"
            with open(temp_tc, "w") as file:
                yaml.dump(md_dict["env"]["training_config"], file)

            temp_ldc = temp_dir / "ldc.yaml"
            with open(temp_ldc, "w") as file:
                yaml.dump(md_dict["env"]["lay_down_config"], file)

            agent = cls(temp_tc, temp_ldc)

            agent.session_path = path

            return agent

        else:
            # Session path does not exist
            msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
            _LOGGER.error(msg)
            raise FileNotFoundError(msg)

    @abstractmethod
    def save(self):
        """Save the agent."""
        self._agent.save(self.session_path)

    @abstractmethod
    def export(self):
        """Export the agent to transportable file format."""
        pass

    def close(self):
        """Close the agent."""
        self._env.episode_av_reward_writer.close()  # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True):
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
|
||||
if learning_session:
|
||||
title += "(Learning)"
|
||||
path = self.learning_path / csv_file
|
||||
image_path = self.learning_path / image_file
|
||||
else:
|
||||
title += "(Evaluation)"
|
||||
path = self.evaluation_path / csv_file
|
||||
image_path = self.evaluation_path / image_file
|
||||
|
||||
fig = plot_av_reward_per_episode(path, title, subtitle)
|
||||
fig.write_image(image_path)
|
||||
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from
|
||||
with all implemented abstract methods implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
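The `_update_session_metadata_file` method above follows a plain read-modify-write pattern on `session_metadata.json`. A minimal, self-contained sketch of that pattern (the temp directory, seed contents, and counter values are illustrative stand-ins, not PrimAITE APIs):

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path

# Illustrative stand-in for the session output directory
session_path = Path(tempfile.mkdtemp())
metadata_file = session_path / "session_metadata.json"

# Seed the file as it would look at session start
metadata_file.write_text(json.dumps({"learning": {}, "evaluation": {}}))

# Read-modify-write: load, patch the relevant section, dump back
metadata = json.loads(metadata_file.read_text())
metadata["end_datetime"] = datetime.now().isoformat()
metadata["learning"]["total_episodes"] = 10
metadata["learning"]["total_time_steps"] = 2560
metadata_file.write_text(json.dumps(metadata))
```

Writing through in real time like this means the metadata on disk stays current even if the agent process later crashes.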
src/primaite/agents/hardcoded_acl.py (new file, 431 lines)
@@ -0,0 +1,431 @@
import numpy as np

from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
    get_new_action,
    get_node_of_ip,
    transform_action_acl_enum,
    transform_change_obs_readable,
)
from primaite.common.enums import HardCodedAgentView


class HardCodedACLAgent(HardCodedAgentSessionABC):
    """An Agent Session class that implements a deterministic ACL agent."""

    def _calculate_action(self, obs):
        if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
            # Basic view: action using only the current observation
            return self._calculate_action_basic_view(obs)
        else:
            # Full view: action using the observation space, action
            # history, and reward feedback
            return self._calculate_action_full_view(obs)

    def get_blocked_green_iers(self, green_iers, acl, nodes):
        """
        Get blocked green IERs.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        blocked_green_iers = {}

        for green_ier_id, green_ier in green_iers.items():
            source_node_id = green_ier.get_source_node_id()
            source_node_address = nodes[source_node_id].ip_address
            dest_node_id = green_ier.get_dest_node_id()
            dest_node_address = nodes[dest_node_id].ip_address
            protocol = green_ier.get_protocol()  # e.g. 'TCP'
            port = green_ier.get_port()

            # Can be blocked by an ACL or by default (no allow rule exists)
            if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
                blocked_green_iers[green_ier_id] = green_ier

        return blocked_green_iers

    def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
        """
        Get matching ACL rules for an IER.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        source_node_id = ier.get_source_node_id()
        source_node_address = nodes[source_node_id].ip_address
        dest_node_id = ier.get_dest_node_id()
        dest_node_address = nodes[dest_node_id].ip_address
        protocol = ier.get_protocol()  # e.g. 'TCP'
        port = ier.get_port()

        matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
        return matching_rules

    def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
        """
        Get blocking ACL rules for an IER.

        .. warning::
            Can return an empty dict while the IER is still blocked by default
            (no ALLOW rule, therefore blocked).

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)

        blocked_rules = {}
        for rule_key, rule_value in matching_rules.items():
            if rule_value.get_permission() == "DENY":
                blocked_rules[rule_key] = rule_value

        return blocked_rules

    def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
        """
        Get all allowing ACL rules for an IER.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)

        allowed_rules = {}
        for rule_key, rule_value in matching_rules.items():
            if rule_value.get_permission() == "ALLOW":
                allowed_rules[rule_key] = rule_value

        return allowed_rules

    def get_matching_acl_rules(
        self,
        source_node_id,
        dest_node_id,
        protocol,
        port,
        acl,
        nodes,
        services_list,
    ):
        """
        Get matching ACL rules.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        if source_node_id != "ANY":
            source_node_address = nodes[str(source_node_id)].ip_address
        else:
            source_node_address = source_node_id

        if dest_node_id != "ANY":
            dest_node_address = nodes[str(dest_node_id)].ip_address
        else:
            dest_node_address = dest_node_id

        if protocol != "ANY":
            protocol = services_list[protocol - 1]  # -1 as we don't have to account for ANY in the list of services

        matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
        return matching_rules

    def get_allow_acl_rules(
        self,
        source_node_id,
        dest_node_id,
        protocol,
        port,
        acl,
        nodes,
        services_list,
    ):
        """
        Get the ALLOW ACL rules.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        matching_rules = self.get_matching_acl_rules(
            source_node_id,
            dest_node_id,
            protocol,
            port,
            acl,
            nodes,
            services_list,
        )

        allowed_rules = {}
        for rule_key, rule_value in matching_rules.items():
            if rule_value.get_permission() == "ALLOW":
                allowed_rules[rule_key] = rule_value

        return allowed_rules

    def get_deny_acl_rules(
        self,
        source_node_id,
        dest_node_id,
        protocol,
        port,
        acl,
        nodes,
        services_list,
    ):
        """
        Get the DENY ACL rules.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        matching_rules = self.get_matching_acl_rules(
            source_node_id,
            dest_node_id,
            protocol,
            port,
            acl,
            nodes,
            services_list,
        )

        deny_rules = {}
        for rule_key, rule_value in matching_rules.items():
            if rule_value.get_permission() == "DENY":
                deny_rules[rule_key] = rule_value

        return deny_rules

    def _calculate_action_full_view(self, obs):
        """
        Calculate a good ACL-based action for the blue agent to take.

        Knowledge of just the observation space is insufficient for a perfect solution, as we also need to know:

        - Which ACL rules already exist, otherwise:
          - The agent would permanently get stuck in a loop of performing the same action over and over
            (the best action is to block something that is already blocked, but the agent cannot know this).
          - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule
            if it doesn't know which rules exist).
        - The green IERs (optional) - the agent often needs to know which traffic it should be allowing. For example,
          in the default config one of the green IERs is blocked by default, but the agent has no way of knowing this
          from the observation space. Additionally, potentially in the future, once a node state
          has been fixed (no longer compromised), it needs a way to know it should re-allow traffic.
          An RL agent can learn what the green IERs are on its own, but a rule-based agent cannot easily do this.

        There doesn't seem to be much that can be done if an Operating or OS State is compromised.

        If a service node becomes compromised, there's a decision to make - do we block that service?
        Pros: it cannot launch an attack on another node, so the node will not be able to be OVERWHELMED.
        Cons: it will block a green IER, decreasing the reward.
        We decide to block the service.

        Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
        nodes once a service becomes overwhelmed. However, the ACL action space currently has no way of reversing
        an overwhelmed state, so we don't do this.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        r_obs = transform_change_obs_readable(obs)
        _, _, _, *s = r_obs

        if len(r_obs) == 4:  # only 1 service
            s = [*s]

        # 1. Check if a node is compromised. If so, we want to block its outward services.
        #    a. If it is compromised, check if there's an allow rule we should delete.
        #       Cons: might delete a multi-rule from any source node (ANY -> x).
        #    b. OPTIONAL (DENY rules not needed): check if there is already an existing DENY rule, so as not to duplicate.
        #    c. OPTIONAL (no allow rule = blocked): add a DENY rule.
        found_action = False
        for service_num, service_states in enumerate(s):
            for x, service_state in enumerate(service_states):
                if service_state == "COMPROMISED":
                    action_source_id = x + 1  # +1 as 0 is ANY
                    action_destination_id = "ANY"
                    action_protocol = service_num + 1  # +1 as 0 is ANY
                    action_port = "ANY"

                    allow_rules = self.get_allow_acl_rules(
                        action_source_id,
                        action_destination_id,
                        action_protocol,
                        action_port,
                        self._env.acl,
                        self._env.nodes,
                        self._env.services_list,
                    )
                    deny_rules = self.get_deny_acl_rules(
                        action_source_id,
                        action_destination_id,
                        action_protocol,
                        action_port,
                        self._env.acl,
                        self._env.nodes,
                        self._env.services_list,
                    )
                    if len(allow_rules) > 0:
                        # Check if there's an allow rule we should delete
                        rule = list(allow_rules.values())[0]
                        action_decision = "DELETE"
                        action_permission = "ALLOW"
                        action_source_ip = rule.get_source_ip()
                        action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
                        action_destination_ip = rule.get_dest_ip()
                        action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
                        action_protocol_name = rule.get_protocol()
                        action_protocol = (
                            self._env.services_list.index(action_protocol_name) + 1
                        )  # convert name e.g. 'TCP' to index
                        action_port_name = rule.get_port()
                        action_port = (
                            self._env.ports_list.index(action_port_name) + 1
                        )  # convert port name e.g. '80' to index

                        found_action = True
                        break
                    elif len(deny_rules) > 0:
                        # TODO OPTIONAL:
                        # If there's already a DENY rule that blocks EVERYTHING from the source IP, we don't need
                        # to create another.
                        # Check whether the DENY rule really blocks everything (ANY) or just a specific rule.
                        continue
                    else:
                        # TODO OPTIONAL: add a DENY rule; optional as by default no allow rule == blocked
                        action_decision = "CREATE"
                        action_permission = "DENY"
                        break
            if found_action:
                break

        # 2. If NO node is compromised, or the node has already been blocked, check the green IERs and
        #    add an ALLOW rule if the green IER is being blocked.
        #    a. OPTIONAL - NOT IMPLEMENTED (optional as a DENY rule does not overwrite an ALLOW rule):
        #       if there's a DENY rule, delete it if:
        #       - There isn't already a deny rule
        #       - It doesn't allow a compromised node to become operational.
        #    b. Add an ALLOW rule if:
        #       - There isn't already an allow rule
        #       - It doesn't allow a compromised node to become operational.

        if not found_action:
            # Which green IERs are blocked
            blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
            for ier_key, ier in blocked_green_iers.items():
                # Which ALLOW rules are allowing this IER (none)
                allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)

                # If there are no blocking rules, it may be being blocked by default
                # If there is already an allow rule
                node_id_to_check = int(ier.get_source_node_id())
                service_name_to_check = ier.get_protocol()
                service_id_to_check = self._env.services_list.index(service_name_to_check)

                # Service state of the source node in the IER
                service_state = s[service_id_to_check][node_id_to_check - 1]

                if len(allowing_rules) == 0 and service_state != "COMPROMISED":
                    action_decision = "CREATE"
                    action_permission = "ALLOW"
                    action_source_id = int(ier.get_source_node_id())
                    action_destination_id = int(ier.get_dest_node_id())
                    action_protocol_name = ier.get_protocol()
                    action_protocol = (
                        self._env.services_list.index(action_protocol_name) + 1
                    )  # convert name e.g. 'TCP' to index
                    action_port_name = ier.get_port()
                    action_port = (
                        self._env.ports_list.index(action_port_name) + 1
                    )  # convert port name e.g. '80' to index

                    found_action = True
                    break

        if found_action:
            action = [
                action_decision,
                action_permission,
                action_source_id,
                action_destination_id,
                action_protocol,
                action_port,
            ]
            action = transform_action_acl_enum(action)
            action = get_new_action(action, self._env.action_dict)
        else:
            # If no good/useful action has been found, just perform a nothing action
            action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
            action = transform_action_acl_enum(action)
            action = get_new_action(action, self._env.action_dict)
        return action

    def _calculate_action_basic_view(self, obs):
        """Calculate a good ACL-based action for the blue agent to take.

        Uses ONLY information from the current observation, with NO knowledge
        of previous actions taken and NO reward feedback.

        We rely on randomness to select the precise action, as we want to
        block all traffic originating from a compromised node without being
        able to tell:

        1. Which ACL rules already exist.
        2. Which actions the agent has already tried.

        There is a high probability that the correct rule will not be deleted
        before the state becomes overwhelmed.

        Currently, a deny rule does not overwrite an allow rule; the allow
        rules must be deleted.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        action_dict = self._env.action_dict
        r_obs = transform_change_obs_readable(obs)
        _, o, _, *s = r_obs

        if len(r_obs) == 4:  # only 1 service
            s = [*s]

        number_of_nodes = len([i for i in o if i != "NONE"])  # number of nodes (not links)
        for service_num, service_states in enumerate(s):
            compromised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
            if len(compromised_states) == 0:
                # No states are COMPROMISED, try the next service
                continue

            compromised_node = np.random.choice(compromised_states) + 1  # +1 as 0 would be ANY
            action_decision = "DELETE"
            action_permission = "ALLOW"
            action_source_ip = compromised_node
            # Randomly select a destination ID to block
            action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
            action_destination_ip = (
                int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
            )
            action_protocol = service_num + 1  # +1 as 0 is ANY
            # Randomly select a port to block.
            # Bad assumption that the number of protocols equals the number of ports
            # AND that no rules exist with an ANY port.
            action_port = np.random.choice(list(range(1, len(s) + 1)))

            action = [
                action_decision,
                action_permission,
                action_source_ip,
                action_destination_ip,
                action_protocol,
                action_port,
            ]
            action = transform_action_acl_enum(action)
            action = get_new_action(action, action_dict)
            # We can only perform 1 action on each step
            return action

        # If no good/useful action has been found, just perform a nothing action
        nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        nothing_action = transform_action_acl_enum(nothing_action)
        nothing_action = get_new_action(nothing_action, action_dict)
        return nothing_action
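The allow/deny helper methods above all share one filtering step: take the matching rules and keep only those with a given permission. That step in isolation (the `Rule` class is a minimal stand-in for PrimAITE's ACL rule objects, not its real API):

```python
class Rule:
    """Illustrative stand-in for an ACL rule carrying a permission."""

    def __init__(self, permission):
        self._permission = permission

    def get_permission(self):
        return self._permission


def filter_rules_by_permission(matching_rules, permission):
    # Keep only the rules whose permission matches (e.g. "ALLOW" or "DENY")
    return {key: rule for key, rule in matching_rules.items() if rule.get_permission() == permission}


rules = {0: Rule("ALLOW"), 1: Rule("DENY"), 2: Rule("ALLOW")}
allow_rules = filter_rules_by_permission(rules, "ALLOW")
deny_rules = filter_rules_by_permission(rules, "DENY")
```

Extracting the shared step like this would also remove the near-duplicate loops in `get_allow_acl_rules` and `get_deny_acl_rules`.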
src/primaite/agents/hardcoded_node.py (new file, 119 lines)
@@ -0,0 +1,119 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable


class HardCodedNodeAgent(HardCodedAgentSessionABC):
    """An Agent Session class that implements a deterministic Node agent."""

    def _calculate_action(self, obs):
        """
        Calculate a good node-based action for the blue agent to take.

        TODO: Add params and return in docstring.
        TODO: Typehint params and return.
        """
        action_dict = self._env.action_dict
        r_obs = transform_change_obs_readable(obs)
        _, o, os, *s = r_obs

        if len(r_obs) == 4:  # only 1 service
            s = [*s]

        # Check in order of most important states (the order doesn't currently
        # matter, but it probably should).
        # First, see if any OS states are compromised
        for x, os_state in enumerate(os):
            if os_state == "COMPROMISED":
                action_node_id = x + 1
                action_node_property = "OS"
                property_action = "PATCHING"
                action_service_index = 0  # the service index isn't relevant for the OS property
                action = [
                    action_node_id,
                    action_node_property,
                    property_action,
                    action_service_index,
                ]
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action

        # Next, see if any services are compromised.
        # We fix the compromised state before the overwhelmed state;
        # if a compromised entry node is fixed before the overwhelmed state is triggered, the instruction is ignored
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "COMPROMISED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num

                    action = [
                        action_node_id,
                        action_node_property,
                        property_action,
                        action_service_index,
                    ]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action

        # Next, see if any services are overwhelmed.
        # Perhaps this should be fixed automatically when the compromised PCs' issues are also resolved.
        # Currently there's no reason that an overwhelmed state cannot be resolved before resolving the compromised PCs.
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "OVERWHELMED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num

                    action = [
                        action_node_id,
                        action_node_property,
                        property_action,
                        action_service_index,
                    ]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action

        # Finally, turn on any OFF nodes
        for x, operating_state in enumerate(o):
            if operating_state == "OFF":
                action_node_id = x + 1
                action_node_property = "OPERATING"
                property_action = "ON"  # Why reset it when we can just turn it on
                action_service_index = 0  # the service index isn't relevant for the operating state
                action = [
                    action_node_id,
                    action_node_property,
                    property_action,
                    action_service_index,
                ]
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action

        # If no good actions, just go with an action that won't do any harm
        action_node_id = 1
        action_node_property = "NONE"
        property_action = "ON"
        action_service_index = 0
        action = [
            action_node_id,
            action_node_property,
            property_action,
            action_service_index,
        ]
        action = transform_action_node_enum(action)
        action = get_new_action(action, action_dict)

        return action
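The node agent's `_calculate_action` above is essentially a priority scan: patch compromised OS states first, then compromised services, then overwhelmed services, then power on OFF nodes, else do nothing. The ordering can be sketched independently of PrimAITE's action encoding (the state strings match the ones used above; the returned tuples are an illustrative encoding, not the real action format):

```python
def choose_node_action(os_states, operating_states, service_states):
    # 1. Patch any compromised OS (node IDs are 1-based, as above)
    for node, state in enumerate(os_states):
        if state == "COMPROMISED":
            return ("OS", "PATCHING", node + 1)
    # 2. Patch any compromised service, then 3. any overwhelmed service
    for target in ("COMPROMISED", "OVERWHELMED"):
        for service_num, states in enumerate(service_states):
            for node, state in enumerate(states):
                if state == target:
                    return ("SERVICE", "PATCHING", node + 1, service_num)
    # 4. Turn any OFF node back on
    for node, state in enumerate(operating_states):
        if state == "OFF":
            return ("OPERATING", "ON", node + 1)
    # 5. Otherwise, a harmless no-op
    return ("NONE",)


action = choose_node_action(
    os_states=["GOOD", "GOOD"],
    operating_states=["ON", "OFF"],
    service_states=[["GOOD", "OVERWHELMED"]],
)
```

Only one action is returned per step, matching the "We can only perform 1 action on each step" comments above.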
src/primaite/agents/rllib.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _env_creator(env_config):
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
session_path=env_config["session_path"],
|
||||
timestamp_str=env_config["timestamp_str"],
|
||||
)
|
||||
|
||||
|
||||
def _custom_log_creator(session_path: Path):
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def logger_creator(config):
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
return logger_creator
|
||||
|
||||
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: Union[PPOConfig, A2CConfig]
|
||||
|
||||
self._current_result: dict
|
||||
self._setup()
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}, "
|
||||
f"deep_learning_framework="
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
with the following key/value pairs:
|
||||
|
||||
- end_datetime: The date & time the session ended in iso format.
|
||||
- total_episodes: The total number of training episodes completed.
|
||||
- total_time_steps: The total number of training time steps completed.
|
||||
"""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._agent_config = self._agent_config_class()
|
||||
|
||||
self._agent_config.environment(
|
||||
env="primaite",
|
||||
env_config=dict(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
    def learn(
        self,
        **kwargs,
    ):
        """
        Train the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        time_steps = self._training_config.num_steps
        episodes = self._training_config.num_episodes

        _LOGGER.info(f"Beginning learning for {episodes} episodes @ {time_steps} time steps...")
        for i in range(episodes):
            self._current_result = self._agent.train()
            self._save_checkpoint()
        self._agent.stop()
        super().learn()

    def evaluate(
        self,
        **kwargs,
    ):
        """
        Evaluate the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        raise NotImplementedError

    def _get_latest_checkpoint(self):
        raise NotImplementedError

    @classmethod
    def load(cls, path: Union[str, Path]) -> RLlibAgent:
        """Load an agent from file."""
        raise NotImplementedError

    def save(self):
        """Save the agent."""
        raise NotImplementedError

    def export(self):
        """Export the agent to transportable file format."""
        raise NotImplementedError

src/primaite/agents/sb3.py (new file, +141 lines)
@@ -0,0 +1,141 @@
from __future__ import annotations

from pathlib import Path
from typing import Union

import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp

from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite

_LOGGER = getLogger(__name__)


class SB3Agent(AgentSessionABC):
    """An AgentSession class that implements a Stable Baselines3 agent."""

    def __init__(self, training_config_path, lay_down_config_path):
        super().__init__(training_config_path, lay_down_config_path)
        if not self._training_config.agent_framework == AgentFramework.SB3:
            msg = f"Expected SB3 agent_framework, got {self._training_config.agent_framework}"
            _LOGGER.error(msg)
            raise ValueError(msg)
        if self._training_config.agent_identifier == AgentIdentifier.PPO:
            self._agent_class = PPO
        elif self._training_config.agent_identifier == AgentIdentifier.A2C:
            self._agent_class = A2C
        else:
            msg = f"Expected PPO or A2C agent_identifier, got {self._training_config.agent_identifier}"
            _LOGGER.error(msg)
            raise ValueError(msg)

        self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
        self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
        self._setup()
        _LOGGER.debug(
            f"Created {self.__class__.__name__} using: "
            f"agent_framework={self._training_config.agent_framework}, "
            f"agent_identifier={self._training_config.agent_identifier}"
        )

        self.is_eval = False

    def _setup(self):
        super()._setup()
        self._env = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
        )
        self._agent = self._agent_class(
            PPOMlp,
            self._env,
            verbose=self.sb3_output_verbose_level,
            n_steps=self._training_config.num_steps,
            tensorboard_log=str(self._tensorboard_log_path),
        )

    def _save_checkpoint(self):
        checkpoint_n = self._training_config.checkpoint_every_n_episodes
        episode_count = self._env.episode_count
        if checkpoint_n > 0 and episode_count > 0:
            if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
                checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
                self._agent.save(checkpoint_path)
                _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")

    def _get_latest_checkpoint(self):
        pass

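The episodic checkpoint condition shared by both `_save_checkpoint` implementations (SB3 and RLlib) can be sketched as a standalone predicate. The function name and flat signature below are illustrative, not part of the PR:

```python
def should_checkpoint(episode_count: int, checkpoint_n: int, total_episodes: int) -> bool:
    """Save every checkpoint_n episodes, and always on the final episode.

    A checkpoint_n of 0 (or below) disables checkpointing entirely.
    """
    if checkpoint_n <= 0 or episode_count <= 0:
        return False
    return (episode_count % checkpoint_n == 0) or (episode_count == total_episodes)
```

Note the final-episode clause: a run whose `num_episodes` is not a multiple of `checkpoint_every_n_episodes` still checkpoints at the very end.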
    def learn(
        self,
        **kwargs,
    ):
        """
        Train the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        time_steps = self._training_config.num_steps
        episodes = self._training_config.num_episodes
        self.is_eval = False
        _LOGGER.info(f"Beginning learning for {episodes} episodes @ {time_steps} time steps...")
        for i in range(episodes):
            self._agent.learn(total_timesteps=time_steps)
            self._save_checkpoint()
            self._env.reset()
        self._env.close()
        super().learn()

    def evaluate(
        self,
        deterministic: bool = True,
        **kwargs,
    ):
        """
        Evaluate the agent.

        :param deterministic: Whether the evaluation is deterministic.
        :param kwargs: Any agent-specific key-word args to be passed.
        """
        time_steps = self._training_config.num_steps
        episodes = self._training_config.num_episodes
        self._env.set_as_eval()
        self.is_eval = True
        if deterministic:
            deterministic_str = "deterministic"
        else:
            deterministic_str = "non-deterministic"
        _LOGGER.info(
            f"Beginning {deterministic_str} evaluation for {episodes} episodes @ {time_steps} time steps..."
        )
        for episode in range(episodes):
            obs = self._env.reset()

            for step in range(time_steps):
                action, _states = self._agent.predict(obs, deterministic=deterministic)
                if isinstance(action, np.ndarray):
                    action = np.int64(action)
                obs, rewards, done, info = self._env.step(action)
        self._env.reset()
        self._env.close()
        super().evaluate()

    @classmethod
    def load(cls, path: Union[str, Path]) -> SB3Agent:
        """Load an agent from file."""
        raise NotImplementedError

    def save(self):
        """Save the agent."""
        raise NotImplementedError

    def export(self):
        """Export the agent to transportable file format."""
        raise NotImplementedError
src/primaite/agents/simple.py (new file, +56 lines)
@@ -0,0 +1,56 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum


class RandomAgent(HardCodedAgentSessionABC):
    """
    A Random Agent.

    Gets a completely random action from the action space.
    """

    def _calculate_action(self, obs):
        return self._env.action_space.sample()


class DummyAgent(HardCodedAgentSessionABC):
    """
    A Dummy Agent.

    All action spaces are set up so that the dummy action is always 0,
    regardless of the action type used.
    """

    def _calculate_action(self, obs):
        return 0


class DoNothingACLAgent(HardCodedAgentSessionABC):
    """
    A do nothing ACL agent.

    A valid ACL action that has no effect; does nothing.
    """

    def _calculate_action(self, obs):
        nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        nothing_action = transform_action_acl_enum(nothing_action)
        nothing_action = get_new_action(nothing_action, self._env.action_dict)

        return nothing_action


class DoNothingNodeAgent(HardCodedAgentSessionABC):
    """
    A do nothing Node agent.

    A valid Node action that has no effect; does nothing.
    """

    def _calculate_action(self, obs):
        nothing_action = [1, "NONE", "ON", 0]
        nothing_action = transform_action_node_enum(nothing_action)
        nothing_action = get_new_action(nothing_action, self._env.action_dict)
        # nothing_action should currently always be 0

        return nothing_action
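The hard-coded agents above all follow the same template: the session loop hands the agent an observation and asks `_calculate_action` for a discrete action. A minimal self-contained sketch of that pattern, using stub classes (`_StubEnv` and `StubRandomAgent` are illustrative stand-ins, not part of the PR):

```python
import random


class _StubEnv:
    """Toy stand-in for the PrimAITE env; only action_space.sample() is used here."""

    class _Space:
        n = 4

        def sample(self):
            return random.randrange(self.n)

    action_space = _Space()


class StubRandomAgent:
    """Mimics RandomAgent: the session loop calls _calculate_action(obs) each step."""

    def __init__(self, env):
        self._env = env

    def _calculate_action(self, obs):
        # Delegate straight to the env's action space, exactly as RandomAgent does
        return self._env.action_space.sample()


agent = StubRandomAgent(_StubEnv())
actions = [agent._calculate_action(obs=None) for _ in range(20)]
```

Swapping `_calculate_action` for `return 0` reproduces `DummyAgent`; the do-nothing agents instead build a known no-op action and look up its index.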
@@ -1,4 +1,13 @@
import numpy as np

from primaite.common.enums import (
    HardwareState,
    LinkStatus,
    NodeHardwareAction,
    NodePOLType,
    NodeSoftwareAction,
    SoftwareState,
)


def transform_action_node_readable(action):
@@ -7,14 +16,15 @@ def transform_action_node_readable(action):

    example:
        [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_node_property = NodePOLType(action[1]).name

    if action_node_property == "OPERATING":
        property_action = NodeHardwareAction(action[2]).name
    elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
        property_action = NodeSoftwareAction(action[2]).name
    else:
        property_action = "NONE"
@@ -29,6 +39,9 @@ def transform_action_acl_readable(action):

    example:
        [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
    action_permissions = {0: "DENY", 1: "ALLOW"}

@@ -53,6 +66,9 @@ def is_valid_node_action(action):
    Does NOT consider:
    - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
    - Node already being in that state (turning an ON node ON)

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_r = transform_action_node_readable(action)

@@ -68,7 +84,10 @@ def is_valid_node_action(action):
    if node_property == "OPERATING" and node_action == "PATCHING":
        # Operating State cannot PATCH
        return False
    if node_property != "OPERATING" and node_action not in [
        "NONE",
        "PATCHING",
    ]:
        # Software States can only do Nothing or Patch
        return False
    return True
@@ -83,6 +102,9 @@ def is_valid_acl_action(action):
    Does NOT consider:
    - Trying to create identical rules
    - Trying to create a rule which is a subset of another rule (caused by "ANY")

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_r = transform_action_acl_readable(action)

@@ -93,11 +115,7 @@ def is_valid_acl_action(action):

    if action_decision == "NONE":
        return False
    if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
        # ACL rule towards itself
        return False
    if action_permission == "DENY":

@@ -109,7 +127,12 @@ def is_valid_acl_action(action):


def is_valid_acl_action_extra(action):
    """
    Harsher version of valid ACL actions; does not allow the ANY option.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    if is_valid_acl_action(action) is False:
        return False

@@ -125,3 +148,406 @@ def is_valid_acl_action_extra(action):
        return False

    return True

def transform_change_obs_readable(obs):
    """
    Transform list of transactions to readable list of each observation property.

    example:
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    ids = [i for i in obs[:, 0]]
    operating_states = [HardwareState(i).name for i in obs[:, 1]]
    os_states = [SoftwareState(i).name for i in obs[:, 2]]
    new_obs = [ids, operating_states, os_states]

    for service in range(3, obs.shape[1]):
        # Links bit/s don't have a service state
        service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
        new_obs.append(service_states)

    return new_obs


def transform_obs_readable(obs):
    """
    Transform observation to readable format.

    example:
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    changed_obs = transform_change_obs_readable(obs)
    new_obs = list(zip(*changed_obs))
    # Convert list of tuples to list of lists
    new_obs = [list(i) for i in new_obs]

    return new_obs


def convert_to_new_obs(obs, num_nodes=10):
    """
    Convert original gym Box observation space to new MultiDiscrete observation space.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    # Remove ID columns, remove links and flatten to MultiDiscrete observation space
    new_obs = obs[:num_nodes, 1:].flatten()
    return new_obs


def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
    """
    Convert to old observation.

    Links filled with 0's as no information is included in new observation space.

    example:
        obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])

        new_obs = array([[ 1,  1,  1,  1],
                         [ 2,  1,  1,  1],
                         [ 3,  1,  1,  1],
                         ...
                         [20,  0,  0,  0]])

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    # Convert back to more readable, original format
    reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)

    # Add empty links back and add node ID back
    s = np.zeros(
        [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
        dtype=np.int64,
    )
    s[:, 0] = range(1, num_nodes + num_links + 1)  # Adding ID back
    s[:num_nodes, 1:] = reshaped_nodes  # put values back in
    new_obs = s

    # Add links back in
    links = obs[-num_links:]
    # Links will be added to the last protocol/service slot but they are not specific to that service
    new_obs[num_nodes:, -1] = links

    return new_obs

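A quick round-trip sanity check of the two conversion helpers above, re-declared here so the snippet is self-contained. The 2-node/2-link toy dimensions are illustrative (the defaults in the PR are 10/10/1):

```python
import numpy as np


def convert_to_new_obs(obs, num_nodes=10):
    # Drop the ID column and the link rows, then flatten (as in the helper above)
    return obs[:num_nodes, 1:].flatten()


def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
    # Rebuild the node table, re-attach IDs, and put link loads in the last column
    reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
    s = np.zeros([num_nodes + num_links, num_services + 3], dtype=np.int64)
    s[:, 0] = range(1, num_nodes + num_links + 1)
    s[:num_nodes, 1:] = reshaped_nodes
    s[num_nodes:, -1] = obs[-num_links:]
    return s


# Flat obs: 2 nodes x (hardware, OS, 1 service) followed by 2 link loads
flat = np.array([1, 1, 3, 2, 1, 1, 0, 4])
old = convert_to_old_obs(flat, num_nodes=2, num_links=2, num_services=1)
```

The functions are not exact inverses: `convert_to_new_obs` drops the link rows entirely, so only the node portion of the flat observation survives a round trip.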
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
    """
    Return string describing change between two observations.

    example:
        obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
        obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
        output = 'ID 1: SERVICE 2 set to GOOD'

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
    obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
    list_of_changes = []
    for n, row in enumerate(obs1 - obs2):
        if row.any() != 0:
            relevant_changes = np.where(row != 0, obs2[n], -1)
            relevant_changes[0] = obs2[n, 0]  # ID is always relevant
            is_link = relevant_changes[0] > num_nodes
            desc = _describe_obs_change_helper(relevant_changes, is_link)
            list_of_changes.append(desc)

    change_string = "\n ".join(list_of_changes)
    if len(list_of_changes) > 0:
        change_string = "\n " + change_string
    return change_string


def _describe_obs_change_helper(obs_change, is_link):
    """
    Helper function to describe what has changed.

    example:
        [ 1 -1 -1 -1  1] -> "ID 1: Service 1 changed to GOOD"

    Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    # Indexes where a change has occurred, not including 0th index
    index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
    # Node pol types, Indexes >= 3 are service nodes
    NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
    # Account for hardware states, software states and links
    states = [
        LinkStatus(obs_change[i]).name
        if is_link
        else HardwareState(obs_change[i]).name
        if i == 1
        else SoftwareState(obs_change[i]).name
        for i in index_changed
    ]

    if not is_link:
        desc = f"ID {obs_change[0]}:"
        for node_pol_type, state in list(zip(NodePOLTypes, states)):
            desc = desc + " " + node_pol_type + " changed to " + state + "."
    else:
        desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."

    return desc

def transform_action_node_enum(action):
    """
    Convert a node action from readable string format to enumerated format.

    example:
        [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_node_id = action[0]
    action_node_property = NodePOLType[action[1]].value

    if action[1] == "OPERATING":
        property_action = NodeHardwareAction[action[2]].value
    elif action[1] == "OS" or action[1] == "SERVICE":
        property_action = NodeSoftwareAction[action[2]].value
    else:
        property_action = 0

    action_service_index = action[3]

    new_action = [
        action_node_id,
        action_node_property,
        property_action,
        action_service_index,
    ]

    return new_action


def transform_action_node_readable(action):
    """
    Convert a node action from enumerated format to readable format.

    example:
        [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_node_property = NodePOLType(action[1]).name

    if action_node_property == "OPERATING":
        property_action = NodeHardwareAction(action[2]).name
    elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
        property_action = NodeSoftwareAction(action[2]).name
    else:
        property_action = "NONE"

    new_action = [action[0], action_node_property, property_action, action[3]]
    return new_action


def node_action_description(action):
    """
    Generate string describing a node-based action.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    if isinstance(action[1], (int, np.int64)):
        # transform action to readable format
        action = transform_action_node_readable(action)

    node_id = action[0]
    node_property = action[1]
    property_action = action[2]
    service_id = action[3]

    if property_action == "NONE":
        return ""
    if node_property == "OPERATING" or node_property == "OS":
        description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
    elif node_property == "SERVICE":
        description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
    else:
        return ""

    return description

def transform_action_acl_enum(action):
    """
    Convert an ACL action from readable string format to enumerated format.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    action_permissions = {"DENY": 0, "ALLOW": 1}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, Ports and Protocols, ANY has value 0, otherwise it's just an index
    new_action = [action_decision, action_permission] + list(action[2:6])
    for n, val in enumerate(list(action[2:6])):
        if val == "ANY":
            new_action[n + 2] = 0

    new_action = np.array(new_action)
    return new_action


def acl_action_description(action):
    """
    Generate string describing an ACL-based action.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    if isinstance(action[0], (int, np.int64)):
        # transform action to readable format
        action = transform_action_acl_readable(action)
    if action[0] == "NONE":
        description = "NO ACL RULE APPLIED"
    else:
        description = (
            f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
            f" for protocol/service index {action[4]} on port index {action[5]}"
        )

    return description


def get_node_of_ip(ip, node_dict):
    """
    Get the node ID of an IP address.

    node_dict: dictionary of nodes where key is ID, and value is the node (can be obtained from env.nodes)

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    for node_key, node_value in node_dict.items():
        node_ip = node_value.ip_address
        if node_ip == ip:
            return node_key

def is_valid_node_action(action):
    """Is the node action an actual valid action.

    Only uses information about the action to determine if the action has an effect.

    Does NOT consider:
    - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
    - Node already being in that state (turning an ON node ON)

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_r = transform_action_node_readable(action)

    node_property = action_r[1]
    node_action = action_r[2]

    if node_property == "NONE":
        return False
    if node_action == "NONE":
        return False
    if node_property == "OPERATING" and node_action == "PATCHING":
        # Operating State cannot PATCH
        return False
    if node_property != "OPERATING" and node_action not in [
        "NONE",
        "PATCHING",
    ]:
        # Software States can only do Nothing or Patch
        return False
    return True


def is_valid_acl_action(action):
    """
    Is the ACL action an actual valid action.

    Only uses information about the action to determine if the action has an effect.

    Does NOT consider:
    - Trying to create identical rules
    - Trying to create a rule which is a subset of another rule (caused by "ANY")

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    action_r = transform_action_acl_readable(action)

    action_decision = action_r[0]
    action_permission = action_r[1]
    action_source_id = action_r[2]
    action_destination_id = action_r[3]

    if action_decision == "NONE":
        return False
    if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
        # ACL rule towards itself
        return False
    if action_permission == "DENY":
        # DENY is unnecessary, we can create and delete allow rules instead
        # No allow rule = blocked/DENY by default. ALLOW overrides existing DENY.
        return False

    return True


def is_valid_acl_action_extra(action):
    """
    Harsher version of valid ACL actions; does not allow the ANY option.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    if is_valid_acl_action(action) is False:
        return False

    action_r = transform_action_acl_readable(action)
    action_protocol = action_r[4]
    action_port = action_r[5]

    # Don't allow protocols or ports to be ANY
    # in the future we might want to do the opposite, and only have the ANY option for ports and services
    if action_protocol == "ANY":
        return False
    if action_port == "ANY":
        return False

    return True


def get_new_action(old_action, action_dict):
    """
    Get new action (e.g. 32) from old action e.g. [1,1,1,0].

    old_action can be either the node or ACL action type.

    TODO: Add params and return in docstring.
    TODO: Typehint params and return.
    """
    for key, val in action_dict.items():
        if list(val) == list(old_action):
            return key
    # Not all possible actions are included in the dict, only valid actions are;
    # if the action is not in the dict, it's an invalid action, so return 0
    return 0
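`get_new_action` above is a plain reverse lookup from a structured action to its discrete index. A minimal sketch with a toy `action_dict` (the keys and values below are illustrative, not taken from a real environment):

```python
import numpy as np


def get_new_action(old_action, action_dict):
    """Return the discrete action key whose value matches old_action; 0 if absent."""
    for key, val in action_dict.items():
        if list(val) == list(old_action):
            return key
    # Invalid actions are not in the dict, so they fall through to the do-nothing action 0
    return 0


# Hypothetical mapping from discrete action index to structured node action
action_dict = {0: [0, 0, 0, 0], 7: [1, 3, 1, 0], 12: [2, 1, 2, 0]}
```

Because the comparison goes through `list(...)`, the lookup works whether `old_action` arrives as a Python list or a numpy array (as produced by `transform_action_acl_enum`).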
@@ -13,6 +13,8 @@ import yaml
from platformdirs import PlatformDirs
from typing_extensions import Annotated

from primaite.data_viz import PlotlyTemplate

app = typer.Typer()


@@ -76,11 +78,12 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
        primaite_config = yaml.safe_load(file)

    if level:
        primaite_config["logging"]["log_level"] = level.value
        with open(user_config_path, "w") as file:
            yaml.dump(primaite_config, file)
        print(f"PrimAITE Log Level: {level}")
    else:
        level = primaite_config["logging"]["log_level"]
        print(f"PrimAITE Log Level: {level}")

@@ -119,30 +122,18 @@ def setup(overwrite_existing: bool = True):
    app_dirs = PlatformDirs(appname="primaite")
    app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
    user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
    build_config = overwrite_existing or (not user_config_path.exists())
    if build_config:
        pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
        shutil.copy2(pkg_config_path, user_config_path)

    from primaite import getLogger
    from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs

    _LOGGER = getLogger(__name__)

    _LOGGER.info("Performing the PrimAITE first-time setup...")

    if build_config:
        _LOGGER.info("Building primaite_config.yaml...")

    _LOGGER.info("Building the PrimAITE app directories...")
    setup_app_dirs.run()
@@ -160,13 +151,54 @@ def setup(overwrite_existing: bool = True):


@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
    """
    Run a PrimAITE session.

    :param tc: The training config filepath. Optional. If no value is passed
        then the example default training config is used from:
        ~/primaite/config/example_config/training/training_config_main.yaml.
    :param ldc: The lay down config filepath. Optional. If no value is passed
        then the example default lay down config is used from:
        ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
    """
    from primaite.config.lay_down_config import dos_very_basic_config_path
    from primaite.config.training_config import main_training_config_path
    from primaite.main import run

    if not tc:
        tc = main_training_config_path()

    if not ldc:
        ldc = dos_very_basic_config_path()

    run(training_config_path=tc, lay_down_config_path=ldc)


@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
    """
    View or set the plotly template for Session plots.

    To view, simply call: primaite plotly-template

    To set, call: primaite plotly-template <desired template>

    For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
    """
    app_dirs = PlatformDirs(appname="primaite")
    app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
    user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
    if user_config_path.exists():
        with open(user_config_path, "r") as file:
            primaite_config = yaml.safe_load(file)

    if template:
        primaite_config["session"]["outputs"]["plots"]["template"] = template.value
        with open(user_config_path, "w") as file:
            yaml.dump(primaite_config, file)
        print(f"PrimAITE plotly template: {template.value}")
    else:
        template = primaite_config["session"]["outputs"]["plots"]["template"]
        print(f"PrimAITE plotly template: {template}")

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""

from enum import Enum
from enum import Enum, IntEnum


class NodeType(Enum):
@@ -32,6 +32,7 @@ class Priority(Enum):
class HardwareState(Enum):
    """Node hardware state enumeration."""

    NONE = 0
    ON = 1
    OFF = 2
    RESETTING = 3
@@ -42,6 +43,7 @@ class HardwareState(Enum):
class SoftwareState(Enum):
    """Software or Service state enumeration."""

    NONE = 0
    GOOD = 1
    PATCHING = 2
    COMPROMISED = 3
@@ -79,6 +81,65 @@ class Protocol(Enum):
    NONE = 7


class SessionType(Enum):
    """The type of PrimAITE Session to be run."""

    TRAIN = 1
    "Train an agent"
    EVAL = 2
    "Evaluate an agent"
    TRAIN_EVAL = 3
    "Train then evaluate an agent"


class AgentFramework(Enum):
    """The agent algorithm framework/package."""

    CUSTOM = 0
    "Custom Agent"
    SB3 = 1
    "Stable Baselines3"
    RLLIB = 2
    "Ray RLlib"


class DeepLearningFramework(Enum):
    """The deep learning framework."""

    TF = "tf"
    "Tensorflow"
    TF2 = "tf2"
    "Tensorflow 2.x"
    TORCH = "torch"
    "PyTorch"


class AgentIdentifier(Enum):
    """The Red Agent algo/class."""

    A2C = 1
    "Advantage Actor Critic"
    PPO = 2
    "Proximal Policy Optimization"
    HARDCODED = 3
    "The Hardcoded agents"
    DO_NOTHING = 4
    "The DoNothing agents"
    RANDOM = 5
    "The RandomAgent"
    DUMMY = 6
    "The DummyAgent"


class HardCodedAgentView(Enum):
    """The view the deterministic hard-coded agent has of the environment."""

    BASIC = 1
    "The current observation space only"
    FULL = 2
    "Full environment view with actions taken and reward feedback"


class ActionType(Enum):
    """Action type enumeration."""

@@ -128,3 +189,11 @@ class LinkStatus(Enum):
    MEDIUM = 2
    HIGH = 3
    OVERLOAD = 4


class SB3OutputVerboseLevel(IntEnum):
    """The Stable Baselines3 learn/eval output verbosity level."""

    NONE = 0
    INFO = 1
    DEBUG = 2
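The new enums above are resolved from plain strings in the YAML main config, so the session code can look members up by name. A minimal sketch of that pattern (the two enums are redeclared here for illustration rather than imported from primaite):

```python
from enum import Enum, IntEnum


class AgentFramework(Enum):
    """Cut-down copy of the AgentFramework enum added in this PR."""
    CUSTOM = 0
    SB3 = 1
    RLLIB = 2


class SB3OutputVerboseLevel(IntEnum):
    """Cut-down copy of the SB3OutputVerboseLevel IntEnum."""
    NONE = 0
    INFO = 1
    DEBUG = 2


# YAML config values arrive as strings, so lookup is by member name:
framework = AgentFramework["SB3"]
print(framework)  # AgentFramework.SB3

# IntEnum members behave as ints, so a verbosity level can be passed
# straight through to an API expecting an integer `verbose` argument:
print(SB3OutputVerboseLevel.INFO == 1)  # True
```

Using `IntEnum` for the verbosity level means no explicit `.value` access is needed at the call site.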
@@ -1,95 +0,0 @@
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """The config class."""
# from dataclasses import dataclass
#
# from primaite.common.enums import ActionType
#
#
# @dataclass()
# class TrainingConfig:
#     """Class to hold main config values."""
#
#     # Generic
#     agent_identifier: str  # The Red Agent algo/class to be used
#     action_type: ActionType  # type of action to use (NODE/ACL/ANY)
#     num_episodes: int  # number of episodes to train over
#     num_steps: int  # number of steps in an episode
#     time_delay: int  # delay between steps (ms) - applies to generic agents only
#     # file
#     session_type: str  # the session type to run (TRAINING or EVALUATION)
#     load_agent: str  # Determine whether to load an agent from file
#     agent_load_file: str  # File path and file name of agent if you're loading one in
#
#     # Environment
#     observation_space_high_value: int  # The high value for the observation space
#
#     # Reward values
#     # Generic
#     all_ok: int
#     # Node Hardware State
#     off_should_be_on: int
#     off_should_be_resetting: int
#     on_should_be_off: int
#     on_should_be_resetting: int
#     resetting_should_be_on: int
#     resetting_should_be_off: int
#     resetting: int
#     # Node Software or Service State
#     good_should_be_patching: int
#     good_should_be_compromised: int
#     good_should_be_overwhelmed: int
#     patching_should_be_good: int
#     patching_should_be_compromised: int
#     patching_should_be_overwhelmed: int
#     patching: int
#     compromised_should_be_good: int
#     compromised_should_be_patching: int
#     compromised_should_be_overwhelmed: int
#     compromised: int
#     overwhelmed_should_be_good: int
#     overwhelmed_should_be_patching: int
#     overwhelmed_should_be_compromised: int
#     overwhelmed: int
#     # Node File System State
#     good_should_be_repairing: int
#     good_should_be_restoring: int
#     good_should_be_corrupt: int
#     good_should_be_destroyed: int
#     repairing_should_be_good: int
#     repairing_should_be_restoring: int
#     repairing_should_be_corrupt: int
#     repairing_should_be_destroyed: int  # Repairing does not fix destroyed state - you need to restore
#
#     repairing: int
#     restoring_should_be_good: int
#     restoring_should_be_repairing: int
#     restoring_should_be_corrupt: int  # Not the optimal method (as repair will fix corruption)
#
#     restoring_should_be_destroyed: int
#     restoring: int
#     corrupt_should_be_good: int
#     corrupt_should_be_repairing: int
#     corrupt_should_be_restoring: int
#     corrupt_should_be_destroyed: int
#     corrupt: int
#     destroyed_should_be_good: int
#     destroyed_should_be_repairing: int
#     destroyed_should_be_restoring: int
#     destroyed_should_be_corrupt: int
#     destroyed: int
#     scanning: int
#     # IER status
#     red_ier_running: int
#     green_ier_blocked: int
#
#     # Patching / Reset
#     os_patching_duration: int  # The time taken to patch the OS
#     node_reset_duration: int  # The time taken to reset a node (hardware)
#     node_booting_duration = 0  # The time taken to turn on the node
#     node_shutdown_duration = 0  # The time taken to turn off the node
#     service_patching_duration: int  # The time taken to patch a service
#     file_system_repairing_limit: int  # The time taken to repair a file
#     file_system_restoring_limit: int  # The time taken to restore a file
#     file_system_scanning_limit: int  # The time taken to scan the file system
#     # Patching / Reset
#
@@ -1,7 +1,3 @@
|
||||
- item_type: ACTIONS
|
||||
type: NODE
|
||||
- item_type: STEPS
|
||||
steps: 128
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
- item_type: ACTIONS
|
||||
type: NODE
|
||||
- item_type: STEPS
|
||||
steps: 128
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
- item_type: ACTIONS
|
||||
type: NODE
|
||||
- item_type: STEPS
|
||||
steps: 256
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,16 +1,42 @@
# Main Config File
# Training Config File

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3

# RED AGENT IDENTIFIER
# RANDOM or NONE
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# "TF2" (Tensorflow 2.x)
# "TORCH" (PyTorch)
deep_learning_framework: TF2

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO

# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False

# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL

# Sets how the Action Space is defined:
# "NODE"
# "ACL"
@@ -25,21 +51,34 @@ observation_space:
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10

# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]

# Sets how often the agent will save a checkpoint (every n episodes).
# Set to 0 if no checkpoints are required. Default is 10.
checkpoint_every_n_episodes: 10

# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5

# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL

# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000

# The Stable Baselines3 learn/eval output verbosity level.
# Options are:
# "NONE" (No output)
# "INFO" (Info messages, such as devices and wrappers used)
# "DEBUG" (All messages)
sb3_output_verbose_level: NONE

# Reward values
# Generic
all_ok: 0

@@ -7,8 +7,10 @@
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C

# RED AGENT IDENTIFIER
# RANDOM or NONE
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True

# Sets how the Action Space is defined:
@@ -1,14 +1,57 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
from typing import Any, Dict, Final, Union

from primaite import USERS_CONFIG_DIR, getLogger
import yaml

from primaite import getLogger, USERS_CONFIG_DIR

_LOGGER = getLogger(__name__)

_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"


def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a legacy lay down config dict to the new format.

    :param legacy_config_dict: A legacy lay down config dict.
    """
    _LOGGER.warning("Legacy lay down config conversion not yet implemented")
    return legacy_config_dict


def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
    """
    Read in a lay down config yaml file.

    :param file_path: The config file path.
    :param legacy_file: True if the config file is legacy format, otherwise
        False.
    :return: The lay down config as a dict.
    :raises ValueError: If the file_path does not exist.
    """
    if not isinstance(file_path, Path):
        file_path = Path(file_path)
    if file_path.exists():
        with open(file_path, "r") as file:
            config = yaml.safe_load(file)
        _LOGGER.debug(f"Loading lay down config file: {file_path}")
        if legacy_file:
            try:
                config = convert_legacy_lay_down_config_dict(config)
            except KeyError:
                msg = (
                    f"Failed to convert lay down config file {file_path} "
                    f"from legacy format. Attempting to use file as is."
                )
                _LOGGER.error(msg)
        return config
    msg = f"Cannot load the lay down config as it does not exist: {file_path}"
    _LOGGER.error(msg)
    raise ValueError(msg)


def ddos_basic_one_config_path() -> Path:
    """
    The path to the example lay_down_config_1_DDOS_basic.yaml file.
@@ -1,58 +1,96 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union

import yaml

from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType
from primaite import getLogger, USERS_CONFIG_DIR
from primaite.common.enums import (
    ActionType,
    AgentFramework,
    AgentIdentifier,
    DeepLearningFramework,
    HardCodedAgentView,
    SB3OutputVerboseLevel,
    SessionType,
)

_LOGGER = getLogger(__name__)

_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"


def main_training_config_path() -> Path:
    """
    The path to the example training_config_main.yaml file.

    :return: The file path.
    """
    path = _EXAMPLE_TRAINING / "training_config_main.yaml"
    if not path.exists():
        msg = "Example config not found. Please run 'primaite setup'"
        _LOGGER.critical(msg)
        raise FileNotFoundError(msg)

    return path


@dataclass()
class TrainingConfig:
    """The Training Config class."""

    # Generic
    agent_identifier: str = "STABLE_BASELINES3_A2C"
    "The Red Agent algo/class to be used."
    agent_framework: AgentFramework = AgentFramework.SB3
    "The AgentFramework"

    deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
    "The DeepLearningFramework"

    agent_identifier: AgentIdentifier = AgentIdentifier.PPO
    "The AgentIdentifier"

    hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
    "The view the deterministic hard-coded agent has of the environment"

    random_red_agent: bool = False
    "Creates Random Red Agent Attacks"

    action_type: ActionType = ActionType.ANY
    "The ActionType to use."
    "The ActionType to use"

    num_episodes: int = 10
    "The number of episodes to train over."
    "The number of episodes to train over"

    num_steps: int = 256
    "The number of steps in an episode."
    observation_space: dict = field(
        default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
    )
    "The observation space config dict."
    "The number of steps in an episode"

    checkpoint_every_n_episodes: int = 5
    "The agent will save a checkpoint every n episodes"

    observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
    "The observation space config dict"

    time_delay: int = 10
    "The delay between steps (ms). Applies to generic agents only."
    "The delay between steps (ms). Applies to generic agents only"

    # file
    session_type: str = "TRAINING"
    "the session type to run (TRAINING or EVALUATION)"
    session_type: SessionType = SessionType.TRAIN
    "The type of PrimAITE session to run"

    load_agent: str = False
    "Determine whether to load an agent from file."
    "Determine whether to load an agent from file"

    agent_load_file: Optional[str] = None
    "File path and file name of agent if you're loading one in."
    "File path and file name of agent if you're loading one in"

    # Environment
    observation_space_high_value: int = 1000000000
    "The high value for the observation space."
    "The high value for the observation space"

    sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
    "Stable Baselines3 learn/eval output verbosity level"

    # Reward values
    # Generic
@@ -117,28 +155,51 @@ class TrainingConfig:

    # Patching / Reset durations
    os_patching_duration: int = 5
    "The time taken to patch the OS."
    "The time taken to patch the OS"

    node_reset_duration: int = 5
    "The time taken to reset a node (hardware)."
    "The time taken to reset a node (hardware)"

    node_booting_duration: int = 3
    "The time taken to turn on the node."
    "The time taken to turn on the node"

    node_shutdown_duration: int = 2
    "The time taken to turn off the node."
    "The time taken to turn off the node"

    service_patching_duration: int = 5
    "The time taken to patch a service."
    "The time taken to patch a service"

    file_system_repairing_limit: int = 5
    "The time taken to repair the file system."
    "The time taken to repair the file system"

    file_system_restoring_limit: int = 5
    "The time taken to restore the file system."
    "The time taken to restore the file system"

    file_system_scanning_limit: int = 5
    "The time taken to scan the file system."
    "The time taken to scan the file system"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
        """
        Create an instance of TrainingConfig from a dict.

        :param config_dict: The training config dict.
        :return: The instance of TrainingConfig.
        """
        field_enum_map = {
            "agent_framework": AgentFramework,
            "deep_learning_framework": DeepLearningFramework,
            "agent_identifier": AgentIdentifier,
            "action_type": ActionType,
            "session_type": SessionType,
            "sb3_output_verbose_level": SB3OutputVerboseLevel,
            "hard_coded_agent_view": HardCodedAgentView,
        }

        for key, value in field_enum_map.items():
            if key in config_dict:
                config_dict[key] = value[config_dict[key]]
        return TrainingConfig(**config_dict)

    def to_dict(self, json_serializable: bool = True):
        """
@@ -150,24 +211,28 @@ class TrainingConfig:
        """
        data = self.__dict__
        if json_serializable:
            data["action_type"] = self.action_type.value
            data["agent_framework"] = self.agent_framework.name
            data["deep_learning_framework"] = self.deep_learning_framework.name
            data["agent_identifier"] = self.agent_identifier.name
            data["action_type"] = self.action_type.name
            data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
            data["session_type"] = self.session_type.name
            data["hard_coded_agent_view"] = self.hard_coded_agent_view.name

        return data


def main_training_config_path() -> Path:
    """
    The path to the example training_config_main.yaml file.

    :return: The file path.
    """
    path = _EXAMPLE_TRAINING / "training_config_main.yaml"
    if not path.exists():
        msg = "Example config not found. Please run 'primaite setup'"
        _LOGGER.critical(msg)
        raise FileNotFoundError(msg)

    return path

    def __str__(self) -> str:
        tc = f"{self.agent_framework}, "
        if self.agent_framework is AgentFramework.RLLIB:
            tc += f"{self.deep_learning_framework}, "
        tc += f"{self.agent_identifier}, "
        if self.agent_identifier is AgentIdentifier.HARDCODED:
            tc += f"{self.hard_coded_agent_view}, "
        tc += f"{self.action_type}, "
        tc += f"observation_space={self.observation_space}, "
        tc += f"{self.num_episodes} episodes @ "
        tc += f"{self.num_steps} steps"
        return tc


def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
@@ -198,15 +263,10 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
                f"from legacy format. Attempting to use file as is."
            )
            _LOGGER.error(msg)
        # Convert values to Enums
        config["action_type"] = ActionType[config["action_type"]]
        try:
            return TrainingConfig(**config)
            return TrainingConfig.from_dict(config)
        except TypeError as e:
            msg = (
                f"Error when creating an instance of {TrainingConfig} "
                f"from the training config file {file_path}"
            )
            msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
            _LOGGER.critical(msg, exc_info=True)
            raise e
    msg = f"Cannot load the training config as it does not exist: {file_path}"
@@ -215,19 +275,35 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf


def convert_legacy_training_config_dict(
    legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
    legacy_config_dict: Dict[str, Any],
    agent_framework: AgentFramework = AgentFramework.SB3,
    agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
    action_type: ActionType = ActionType.ANY,
    num_steps: int = 256,
) -> Dict[str, Any]:
    """
    Convert a legacy training config dict to the new format.

    :param legacy_config_dict: A legacy training config dict.
    :param num_steps: The number of steps to set as legacy training configs
        don't have num_steps values.
    :param agent_framework: The agent framework to use as legacy training
        configs don't have agent_framework values.
    :param agent_identifier: The red agent identifier to use as legacy
        training configs don't have agent_identifier values.
    :param action_type: The action space type to set as legacy training configs
        don't have action_type values.
    :param num_steps: The number of steps to set as legacy training configs
        don't have num_steps values.
    :return: The converted training config dict.
    """
    config_dict = {"num_steps": num_steps, "action_type": action_type}
    config_dict = {
        "agent_framework": agent_framework.name,
        "agent_identifier": agent_identifier.name,
        "action_type": action_type.name,
        "num_steps": num_steps,
        "sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
    }
    session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
    legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
    for legacy_key, value in legacy_config_dict.items():
        new_key = _get_new_key_from_legacy(legacy_key)
        if new_key:
@@ -243,7 +319,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
    :return: The mapped key.
    """
    key_mapping = {
        "agentIdentifier": "agent_identifier",
        "agentIdentifier": None,
        "numEpisodes": "num_episodes",
        "timeDelay": "time_delay",
        "configFilename": None,
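The `TrainingConfig.from_dict` hunk above replaces the old single-field `ActionType` conversion with a field-to-enum map. The sketch below reproduces that idea in isolation with a cut-down config (the enum classes are redeclared here for illustration, not imported from primaite):

```python
from enum import Enum


class AgentFramework(Enum):
    """Cut-down stand-in for primaite.common.enums.AgentFramework."""
    CUSTOM = 0
    SB3 = 1
    RLLIB = 2


class SessionType(Enum):
    """Cut-down stand-in for primaite.common.enums.SessionType."""
    TRAIN = 1
    EVAL = 2
    TRAIN_EVAL = 3


def convert_fields(config_dict: dict) -> dict:
    """Replace string values with enum members, mirroring from_dict above."""
    field_enum_map = {
        "agent_framework": AgentFramework,
        "session_type": SessionType,
    }
    for key, enum_cls in field_enum_map.items():
        if key in config_dict:
            # Enum lookup by member name, e.g. AgentFramework["RLLIB"]
            config_dict[key] = enum_cls[config_dict[key]]
    return config_dict


config = convert_fields(
    {"agent_framework": "RLLIB", "session_type": "TRAIN_EVAL", "num_steps": 256}
)
print(config["agent_framework"])  # AgentFramework.RLLIB
```

Keys not in the map (like `num_steps`) pass through untouched, which is why the same dict can then be splatted into the dataclass constructor.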
src/primaite/data_viz/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from enum import Enum


class PlotlyTemplate(Enum):
    """The built-in plotly templates."""

    PLOTLY = "plotly"
    PLOTLY_WHITE = "plotly_white"
    PLOTLY_DARK = "plotly_dark"
    GGPLOT2 = "ggplot2"
    SEABORN = "seaborn"
    SIMPLE_WHITE = "simple_white"
    NONE = "none"
src/primaite/data_viz/session_plots.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from pathlib import Path
from typing import Dict, Optional, Union

import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure

from primaite import _PLATFORM_DIRS


def _get_plotly_config() -> Dict:
    """Get the plotly config from primaite_config.yaml."""
    user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
    with open(user_config_path, "r") as file:
        primaite_config = yaml.safe_load(file)
    return primaite_config["session"]["outputs"]["plots"]


def plot_av_reward_per_episode(
    av_reward_per_episode_csv: Union[str, Path],
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
) -> Figure:
    """
    Plot the average reward per episode from a csv session output.

    :param av_reward_per_episode_csv: The average reward per episode csv
        file path.
    :param title: The plot title. This is optional.
    :param subtitle: The plot subtitle. This is optional.
    :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
    """
    df = pl.read_csv(av_reward_per_episode_csv)

    if title:
        if subtitle:
            title = f"{title} <br><sup>{subtitle}</sup>"
    else:
        if subtitle:
            title = subtitle

    config = _get_plotly_config()
    layout = go.Layout(
        autosize=config["size"]["auto_size"],
        width=config["size"]["width"],
        height=config["size"]["height"],
    )
    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=config["template"])
    fig.add_trace(
        go.Scatter(
            x=df["Episode"],
            y=df["Average Reward"],
            mode="lines",
            name="Mean Reward per Episode",
        )
    )

    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
            "rangeslider": {"visible": config["range_slider"]},
        },
        yaxis={"title": "Average Reward"},
        title=title,
        showlegend=False,
    )

    return fig
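The plotting helper above consumes a csv with `Episode` and `Average Reward` columns (column names taken from the function body). A minimal stdlib sketch of producing and reading a file in that shape, with made-up reward values:

```python
import csv
import io

# Build a small in-memory csv in the shape plot_av_reward_per_episode expects.
rows = [(1, -50.0), (2, -20.5), (3, 10.0)]
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["Episode", "Average Reward"])
writer.writerows(rows)

# Read it back the way a plotting helper would.
buffer.seek(0)
reader = csv.DictReader(buffer)
records = list(reader)
episodes = [int(r["Episode"]) for r in records]
rewards = [float(r["Average Reward"]) for r in records]
print(episodes)  # [1, 2, 3]
```

Writing these rows as the session runs (rather than buffering them) is what lets the overhauled transactions survive an agent crash.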
@@ -1,7 +1,7 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union

import numpy as np
from gym import spaces
@@ -63,7 +63,7 @@ class NodeLinkTable(AbstractObservationComponent):
    """

    _FIXED_PARAMETERS: int = 4
    _MAX_VAL: int = 1_000_000
    _MAX_VAL: int = 1_000_000_000
    _DATA_TYPE: type = np.int64

    def __init__(self, env: "Primaite"):
@@ -101,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent):
            self.current_observation[item_index][1] = node.hardware_state.value
            if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
                self.current_observation[item_index][2] = node.software_state.value
                self.current_observation[item_index][
                    3
                ] = node.file_system_state_observed.value
                self.current_observation[item_index][3] = node.file_system_state_observed.value
            else:
                self.current_observation[item_index][2] = 0
                self.current_observation[item_index][3] = 0
@@ -111,9 +109,7 @@ class NodeLinkTable(AbstractObservationComponent):
            if isinstance(node, ServiceNode):
                for service in self.env.services_list:
                    if node.has_service(service):
                        self.current_observation[item_index][
                            service_index
                        ] = node.get_service_state(service).value
                        self.current_observation[item_index][service_index] = node.get_service_state(service).value
                    else:
                        self.current_observation[item_index][service_index] = 0
                    service_index += 1
@@ -133,9 +129,7 @@ class NodeLinkTable(AbstractObservationComponent):
            protocol_list = link.get_protocol_list()
            protocol_index = 0
            for protocol in protocol_list:
                self.current_observation[item_index][
                    protocol_index + 4
                ] = protocol.get_load()
                self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
                protocol_index += 1
            item_index += 1

@@ -244,7 +238,12 @@ class NodeStatuses(AbstractObservationComponent):
                if node.has_service(service):
                    service_states[i] = node.get_service_state(service).value
            obs.extend(
                [hardware_state, software_state, file_system_state, *service_states]
                [
                    hardware_state,
                    software_state,
                    file_system_state,
                    *service_states,
                ]
            )
        self.current_observation[:] = obs

@@ -267,9 +266,7 @@ class NodeStatuses(AbstractObservationComponent):
            for service in services:
                structure.append(f"node_{node_id}_service_{service}_state_NONE")
                for state in SoftwareState:
                    structure.append(
                        f"node_{node_id}_service_{service}_state_{state.name}"
                    )
                    structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
        return structure


@@ -325,9 +322,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
            self._entries_per_link = self.env.num_services

        # 1. Define the shape of your observation space component
        shape = (
            [self._quantisation_levels] * self.env.num_links * self._entries_per_link
        )
        shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link

        # 2. Create Observation space
        self.space = spaces.MultiDiscrete(shape)
@@ -356,9 +351,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
        elif load >= bandwidth:
            traffic_level = self._quantisation_levels - 1
        else:
            traffic_level = (load / bandwidth) // (
                1 / (self._quantisation_levels - 2)
            ) + 1
            traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1

        obs.append(int(traffic_level))
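The final `LinkTrafficLevels` hunk above quantises link load into discrete traffic bands. A worked sketch of that formula (`_quantisation_levels` assumed to be 5; the zero-load branch is assumed, since the diff only shows the `elif`):

```python
def traffic_level(load: float, bandwidth: float, quantisation_levels: int = 5) -> int:
    """Band a link's load, mirroring the LinkTrafficLevels formula above."""
    if load <= 0:
        return 0  # assumed: no traffic maps to the bottom band
    if load >= bandwidth:
        return quantisation_levels - 1  # a saturated link gets the top band
    # Intermediate loads fall into evenly-sized bands 1..quantisation_levels-2.
    return int((load / bandwidth) // (1 / (quantisation_levels - 2)) + 1)


print(traffic_level(0, 100))    # 0
print(traffic_level(40, 100))   # 2
print(traffic_level(100, 100))  # 4
```

With 5 levels the intermediate range splits into thirds, so 0-33% utilisation maps to band 1, 33-66% to band 2, and 66-100% to band 3, with bands 0 and 4 reserved for idle and saturated links.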
@@ -1,13 +1,11 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
||||
import copy
|
||||
import csv
|
||||
import logging
|
||||
import uuid as uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, Final, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -15,11 +13,13 @@ import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
@@ -27,6 +27,7 @@ from primaite.common.enums import (
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SessionType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
@@ -45,9 +46,9 @@ from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
```diff
 class Primaite(Env):
@@ -63,7 +64,6 @@ class Primaite(Env):
         self,
         training_config_path: Union[str, Path],
         lay_down_config_path: Union[str, Path],
-        transaction_list,
         session_path: Path,
         timestamp_str: str,
     ):
@@ -72,26 +72,23 @@ class Primaite(Env):

         :param training_config_path: The training config filepath.
         :param lay_down_config_path: The lay down config filepath.
-        :param transaction_list: The list of transactions to populate.
         :param session_path: The directory path the session is writing to.
         :param timestamp_str: The session timestamp in the format:
             <yyyy-mm-dd>_<hh-mm-ss>.
         """
         self.session_path: Final[Path] = session_path
         self.timestamp_str: Final[str] = timestamp_str
         self._training_config_path = training_config_path
         self._lay_down_config_path = lay_down_config_path

-        self.training_config: TrainingConfig = training_config.load(
-            training_config_path
-        )
+        self.training_config: TrainingConfig = training_config.load(training_config_path)
         _LOGGER.info(f"Using: {str(self.training_config)}")

         # Number of steps in an episode
         self.episode_steps = self.training_config.num_steps

         super(Primaite, self).__init__()

-        # Transaction list
-        self.transaction_list = transaction_list
+        # The agent in use
+        self.agent_identifier = self.training_config.agent_identifier

@@ -178,6 +175,9 @@ class Primaite(Env):
         # It will be initialised later.
         self.obs_handler: ObservationsHandler

+        self._obs_space_description = None
+        "The env observation space description for transactions writing"
+
         # Open the config file and build the environment laydown
         with open(self._lay_down_config_path, "r") as file:
```
```diff
@@ -204,16 +204,14 @@ class Primaite(Env):
             plt.savefig(file_path, format="PNG")
             plt.clf()
         except Exception:
-            _LOGGER.error("Could not save network diagram")
-            _LOGGER.error("Exception occured", exc_info=True)
-            print("Could not save network diagram")
+            _LOGGER.error("Could not save network diagram", exc_info=True)

         # Initiate observation space
         self.observation_space, self.env_obs = self.init_observations()

         # Define Action Space - depends on action space type (Node or ACL)
         if self.training_config.action_type == ActionType.NODE:
-            _LOGGER.info("Action space type NODE selected")
+            _LOGGER.debug("Action space type NODE selected")
             # Terms (for node action space):
             # [0, num nodes] - node ID (0 = nothing, node ID)
             # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
@@ -222,7 +220,7 @@ class Primaite(Env):
             self.action_dict = self.create_node_action_dict()
             self.action_space = spaces.Discrete(len(self.action_dict))
         elif self.training_config.action_type == ActionType.ACL:
-            _LOGGER.info("Action space type ACL selected")
+            _LOGGER.debug("Action space type ACL selected")
             # Terms (for ACL action space):
             # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
             # [0, 1] - Permission (0 = DENY, 1 = ALLOW)
@@ -233,27 +231,29 @@ class Primaite(Env):
             self.action_dict = self.create_acl_action_dict()
             self.action_space = spaces.Discrete(len(self.action_dict))
         elif self.training_config.action_type == ActionType.ANY:
-            _LOGGER.info("Action space type ANY selected - Node + ACL")
+            _LOGGER.debug("Action space type ANY selected - Node + ACL")
             self.action_dict = self.create_node_and_acl_action_dict()
             self.action_space = spaces.Discrete(len(self.action_dict))
         else:
-            _LOGGER.info(
-                f"Invalid action type selected: {self.training_config.action_type}"
-            )
-        # Set up a csv to store the results of the training
-        try:
-            header = ["Episode", "Average Reward"]
-
-            file_name = f"average_reward_per_episode_{timestamp_str}.csv"
-            file_path = session_path / file_name
-            self.csv_file = open(file_path, "w", encoding="UTF8", newline="")
-            self.csv_writer = csv.writer(self.csv_file)
-            self.csv_writer.writerow(header)
-        except Exception:
-            _LOGGER.error(
-                "Could not create csv file to hold average reward per episode"
-            )
-            _LOGGER.error("Exception occured", exc_info=True)
+            _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
+
+        self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
+        self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)

+    @property
+    def actual_episode_count(self) -> int:
+        """Shifts the episode_count by -1 for RLlib."""
+        if self.training_config.agent_framework is AgentFramework.RLLIB:
+            return self.episode_count - 1
+        return self.episode_count
+
+    def set_as_eval(self):
+        """Set the writers to write to eval directories."""
+        self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
+        self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
+        self.episode_count = 0
+        self.step_count = 0
+        self.total_step_count = 0

     def reset(self):
         """
@@ -262,12 +262,14 @@ class Primaite(Env):
         Returns:
             Environment observation space (reset)
         """
-        csv_data = self.episode_count, self.average_reward
-        self.csv_writer.writerow(csv_data)
+        if self.actual_episode_count > 0:
+            csv_data = self.actual_episode_count, self.average_reward
+            self.episode_av_reward_writer.write(csv_data)

         self.episode_count += 1

-        # Don't need to reset links, as they are cleared and recalculated every step
+        # Don't need to reset links, as they are cleared and recalculated every
+        # step

         # Clear the ACL
         self.init_acl()
```
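The new `actual_episode_count` property above shifts RLlib's raw episode counter down by one so that the writers guarded by `actual_episode_count > 0` number episodes consistently across frameworks. A standalone sketch of that same off-by-one shift (the enum and counter class here are illustrative stand-ins, not the real PrimAITE types):

```python
from enum import Enum


class AgentFramework(Enum):
    SB3 = "SB3"
    RLLIB = "RLLIB"


class EpisodeCounter:
    """Tracks episodes, shifting the count by -1 when the framework is RLlib."""

    def __init__(self, framework: AgentFramework) -> None:
        self.framework = framework
        self.episode_count = 0

    @property
    def actual_episode_count(self) -> int:
        # RLlib's raw counter runs one ahead of the episodes actually
        # completed, so shift it down to align with SB3's numbering.
        if self.framework is AgentFramework.RLLIB:
            return self.episode_count - 1
        return self.episode_count
```

Checking the shift only against `AgentFramework` membership (rather than a string) mirrors the diff's use of `is` with enum members.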
```diff
@@ -287,6 +289,7 @@ class Primaite(Env):

         # Update observations space and return
         self.update_environent_obs()

         return self.env_obs

     def step(self, action):
@@ -302,15 +305,10 @@ class Primaite(Env):
             done: Indicates episode is complete if True
             step_info: Additional information relating to this step
         """
-        if self.step_count == 0:
-            print(f"Episode: {str(self.episode_count)}")
-
-        # TEMP
         done = False

         self.step_count += 1
         self.total_step_count += 1
-        # print("Episode step: " + str(self.step_count))

         # Need to clear traffic on all links first
         for link_key, link_value in self.links.items():
@@ -320,14 +318,15 @@ class Primaite(Env):
             link.clear_traffic()

         # Create a Transaction (metric) object for this step
-        transaction = Transaction(
-            datetime.now(), self.agent_identifier, self.episode_count, self.step_count
-        )
+        transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
         # Load the initial observation space into the transaction
-        transaction.set_obs_space(self.obs_handler._flat_observation)
+        transaction.obs_space = self.obs_handler._flat_observation

+        # Set the transaction obs space description
+        transaction.obs_space_description = self._obs_space_description
+
         # Load the action space into the transaction
-        transaction.set_action_space(copy.deepcopy(action))
+        transaction.action_space = copy.deepcopy(action)

         # 1. Implement Blue Action
         self.interpret_action_and_apply(action)
@@ -371,9 +370,7 @@ class Primaite(Env):
             self.acl,
             self.step_count,
         )
-        apply_red_agent_node_pol(
-            self.nodes, self.red_iers, self.red_node_pol, self.step_count
-        )
+        apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
         # Take snapshots of nodes and links
         self.nodes_post_red = copy.deepcopy(self.nodes)
         self.links_post_red = copy.deepcopy(self.links)
@@ -389,17 +386,17 @@ class Primaite(Env):
             self.step_count,
             self.training_config,
         )
-        # print(f"  Step {self.step_count} Reward: {str(reward)}")
+        _LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}")
         self.total_reward += reward
         if self.step_count == self.episode_steps:
             self.average_reward = self.total_reward / self.step_count
-            if self.training_config.session_type == "EVALUATION":
+            if self.training_config.session_type is SessionType.EVAL:
                 # For evaluation, need to trigger the done value = True when
                 # step count is reached in order to prevent neverending episode
                 done = True
-            print(f"  Average Reward: {str(self.average_reward)}")
+            _LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}")
         # Load the reward into the transaction
-        transaction.set_reward(reward)
+        transaction.reward = reward

         # 6. Output Verbose
         # self.output_link_status()
@@ -407,15 +404,21 @@ class Primaite(Env):
         # 7. Update env_obs
         self.update_environent_obs()

-        # 8. Add the transaction to the list of transactions
-        self.transaction_list.append(copy.deepcopy(transaction))
+        # Write transaction to file
+        if self.actual_episode_count > 0:
+            self.transaction_writer.write(transaction)

         # Return
         return self.env_obs, reward, done, self.step_info
```
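In the rewritten `step()` above, each `Transaction` is handed straight to `self.transaction_writer.write(...)` instead of being appended to an in-memory list; that per-step write is what keeps transactions on disk if the agent crashes mid-run. A minimal stand-in showing the write-and-flush-per-row idea (the real `SessionOutputWriter` lives in `primaite.utils.session_output_writer` and its API differs in detail):

```python
import csv
from pathlib import Path


class StepWriter:
    """Writes one CSV row per step and flushes immediately, so rows
    written so far survive a crash part-way through an episode."""

    def __init__(self, filepath: Path) -> None:
        self._file = open(filepath, "w", encoding="UTF8", newline="")
        self._writer = csv.writer(self._file)

    def write(self, row: tuple) -> None:
        self._writer.writerow(row)
        self._file.flush()  # do not wait for close() to push rows to disk

    def close(self) -> None:
        self._file.close()
```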
```diff
-    def __close__(self):
-        """Override close function."""
-        self.csv_file.close()
+    def close(self):
+        """Override parent close and close writers."""
+        # Close files if last episode/step
+        # if self.can_finish:
+        super().close()
+
+        self.transaction_writer.close()
+        self.episode_av_reward_writer.close()

     def init_acl(self):
         """Initialise the Access Control List."""
@@ -424,14 +427,9 @@ class Primaite(Env):
     def output_link_status(self):
         """Output the link status of all links to the console."""
         for link_key, link_value in self.links.items():
-            print("Link ID: " + link_value.get_id())
+            _LOGGER.debug("Link ID: " + link_value.get_id())
             for protocol in link_value.protocol_list:
-                print(
-                    "   Protocol: "
-                    + protocol.get_name().name
-                    + ", Load: "
-                    + str(protocol.get_load())
-                )
+                print("   Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))

     def interpret_action_and_apply(self, _action):
         """
@@ -446,13 +444,9 @@ class Primaite(Env):
             self.apply_actions_to_nodes(_action)
         elif self.training_config.action_type == ActionType.ACL:
             self.apply_actions_to_acl(_action)
-        elif (
-            len(self.action_dict[_action]) == 6
-        ):  # ACL actions in multidiscrete form have len 6
+        elif len(self.action_dict[_action]) == 6:  # ACL actions in multidiscrete form have len 6
             self.apply_actions_to_acl(_action)
-        elif (
-            len(self.action_dict[_action]) == 4
-        ):  # Node actions in multdiscrete (array) from have len 4
+        elif len(self.action_dict[_action]) == 4:  # Node actions in multdiscrete (array) from have len 4
             self.apply_actions_to_nodes(_action)
         else:
             logging.error("Invalid action type found")
@@ -518,9 +512,7 @@ class Primaite(Env):
                     return
             elif property_action == 1:
                 # Patch (valid action if it's good or compromised)
-                node.set_service_state(
-                    self.services_list[service_index], SoftwareState.PATCHING
-                )
+                node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING)
             else:
                 # Node is not of Service Type
                 return
@@ -675,6 +667,9 @@ class Primaite(Env):
         """
         self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)

+        if not self._obs_space_description:
+            self._obs_space_description = self.obs_handler.describe_structure()
+
         return self.obs_handler.space, self.obs_handler.current_observation

     def update_environent_obs(self):
@@ -1225,11 +1220,7 @@ class Primaite(Env):

         # Change node keys to not overlap with acl keys
         # Only 1 nothing action (key 0) is required, remove the other
-        new_node_action_dict = {
-            k + len(acl_action_dict) - 1: v
-            for k, v in node_action_dict.items()
-            if k != 0
-        }
+        new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}

         # Combine the Node dict and ACL dict
         combined_action_dict = {**acl_action_dict, **new_node_action_dict}
@@ -1244,9 +1235,7 @@ class Primaite(Env):
         # Decide how many nodes become compromised
         node_list = list(self.nodes.values())
         computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
-        max_num_nodes_compromised = len(
-            computers
-        )  # only computers can become compromised
+        max_num_nodes_compromised = len(computers)  # only computers can become compromised
         # random select between 1 and max_num_nodes_compromised
         num_nodes_to_compromise = randint(1, max_num_nodes_compromised)

@@ -1257,9 +1246,7 @@ class Primaite(Env):
         source_node = choice(nodes_to_be_compromised)

         # For each of the nodes to be compromised decide which step they become compromised
-        max_step_compromised = (
-            self.episode_steps // 2
-        )  # always compromise in first half of episode
+        max_step_compromised = self.episode_steps // 2  # always compromise in first half of episode

         # Bandwidth for all links
         bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
@@ -1308,9 +1295,7 @@ class Primaite(Env):
         ier_protocol = pol_service_name  # Same protocol as compromised node
         ier_service = node.services[pol_service_name]
         ier_port = ier_service.port
-        ier_mission_criticality = (
-            0  # Red IER will never be important to green agent success
-        )
+        ier_mission_criticality = 0  # Red IER will never be important to green agent success
         # We choose a node to attack based on the first that applies:
         # a. Green IERs, select dest node of the red ier based on dest node of green IER
         # b. Attack a random server that doesn't have a DENY acl rule in default config
```
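The `create_node_and_acl_action_dict` hunk above merges the ACL and node action dictionaries into one discrete action space: both dicts carry a "do nothing" action at key 0, so the node dict's key 0 is dropped and its remaining keys are shifted to follow straight on from the last ACL key. The same merge in isolation (the string action values are dummies standing in for the real multidiscrete arrays):

```python
def combine_action_dicts(acl_action_dict: dict, node_action_dict: dict) -> dict:
    """Merge ACL and node action dicts into a single discrete action space.

    Only one 'nothing' action (key 0) is kept, so the node dict's key 0 is
    removed and its other keys are offset past the ACL keys.
    """
    new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
    return {**acl_action_dict, **new_node_action_dict}
```

With three ACL actions (keys 0-2), node action 1 lands at key 3, giving a gap-free key range suitable for `spaces.Discrete`.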
```diff
@@ -41,27 +41,19 @@ def calculate_reward_function(
         reference_node = reference_nodes[node_key]

         # Hardware State
-        reward_value += score_node_operating_state(
-            final_node, initial_node, reference_node, config_values
-        )
+        reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)

         # Software State
         if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
-            reward_value += score_node_os_state(
-                final_node, initial_node, reference_node, config_values
-            )
+            reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)

         # Service State
         if isinstance(final_node, ServiceNode):
-            reward_value += score_node_service_state(
-                final_node, initial_node, reference_node, config_values
-            )
+            reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)

         # File System State
         if isinstance(final_node, ActiveNode):
-            reward_value += score_node_file_system(
-                final_node, initial_node, reference_node, config_values
-            )
+            reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)

         # Go through each red IER - penalise if it is running
         for ier_key, ier_value in red_iers.items():
@@ -80,14 +72,9 @@ def calculate_reward_function(
         if step_count >= start_step and step_count <= stop_step:
             reference_blocked = not reference_ier.get_is_running()
             live_blocked = not ier_value.get_is_running()
-            ier_reward = (
-                config_values.green_ier_blocked * ier_value.get_mission_criticality()
-            )
+            ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()

             if live_blocked and not reference_blocked:
-                _LOGGER.debug(
-                    f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
-                )
                 reward_value += ier_reward
             elif live_blocked and reference_blocked:
                 _LOGGER.debug(
```
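The green-IER hunk above only applies the blocked penalty when the live IER is blocked but the reference (baseline) environment shows it running, and scales the penalty by the IER's mission criticality. That comparison, reduced to a pure function (`green_ier_blocked` would come from the training config; the names mirror the diff but the signature is a stand-in):

```python
def score_green_ier(live_running: bool, reference_running: bool,
                    green_ier_blocked: float, mission_criticality: int) -> float:
    """Return the reward contribution for one green IER at this step."""
    live_blocked = not live_running
    reference_blocked = not reference_running
    ier_reward = green_ier_blocked * mission_criticality
    # Penalise only when the live IER is blocked while the reference
    # environment shows it should have been running.
    if live_blocked and not reference_blocked:
        return ier_reward
    return 0.0
```

If the reference IER is also blocked, the agent is not at fault and no penalty is applied.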
```diff
@@ -1,338 +1,29 @@
 # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
-"""
-The main PrimAITE session runner module.
-
-TODO: This will eventually be refactored out into a proper Session class.
-TODO: The passing about of session_dir and timestamp_str is temporary and
-    will be cleaned up once we move to a proper Session class.
-"""
+"""The main PrimAITE session runner module."""
 import argparse
-import json
-import time
-from datetime import datetime
 from pathlib import Path
-from typing import Final, Union
-from uuid import uuid4
+from typing import Union

-from stable_baselines3 import A2C, PPO
-from stable_baselines3.common.evaluation import evaluate_policy
-from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.ppo import MlpPolicy as PPOMlp
-
-from primaite import SESSIONS_DIR, getLogger
-from primaite.config.training_config import TrainingConfig
-from primaite.environment.primaite_env import Primaite
-from primaite.transactions.transactions_to_file import write_transaction_to_file
+from primaite import getLogger
+from primaite.primaite_session import PrimaiteSession

 _LOGGER = getLogger(__name__)

```
```diff
-def run_generic(env: Primaite, config_values: TrainingConfig):
-    """
-    Run against a generic agent.
-
-    :param env: An instance of
-        :class:`~primaite.environment.primaite_env.Primaite`.
-    :param config_values: An instance of
-        :class:`~primaite.config.training_config.TrainingConfig`.
-    """
-    for episode in range(0, config_values.num_episodes):
-        env.reset()
-        for step in range(0, config_values.num_steps):
-            # Send the observation space to the agent to get an action
-            # TEMP - random action for now
-            # action = env.blue_agent_action(obs)
-            action = env.action_space.sample()
-
-            # Run the simulation step on the live environment
-            obs, reward, done, info = env.step(action)
-
-            # Break if done is True
-            if done:
-                break
-
-            # Introduce a delay between steps
-            time.sleep(config_values.time_delay / 1000)
-
-        # Reset the environment at the end of the episode
-
-    env.close()
-
-
-def run_stable_baselines3_ppo(
-    env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
+def run(
+    training_config_path: Union[str, Path],
+    lay_down_config_path: Union[str, Path],
 ):
     """
-    Run against a stable_baselines3 PPO agent.
-
-    :param env: An instance of
-        :class:`~primaite.environment.primaite_env.Primaite`.
-    :param config_values: An instance of
-        :class:`~primaite.config.training_config.TrainingConfig`.
-    :param session_path: The directory path the session is writing to.
-    :param timestamp_str: The session timestamp in the format:
-        <yyyy-mm-dd>_<hh-mm-ss>.
-    """
-    if config_values.load_agent:
-        try:
-            agent = PPO.load(
-                config_values.agent_load_file,
-                env,
-                verbose=0,
-                n_steps=config_values.num_steps,
-            )
-        except Exception:
-            print(
-                "ERROR: Could not load agent at location: "
-                + config_values.agent_load_file
-            )
-            _LOGGER.error("Could not load agent")
-            _LOGGER.error("Exception occured", exc_info=True)
-    else:
-        agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
-
-    if config_values.session_type == "TRAINING":
-        # We're in a training session
-        print("Starting training session...")
-        _LOGGER.debug("Starting training session...")
-        for episode in range(config_values.num_episodes):
-            agent.learn(total_timesteps=config_values.num_steps)
-        _save_agent(agent, session_path, timestamp_str)
-    else:
-        # Default to being in an evaluation session
-        print("Starting evaluation session...")
-        _LOGGER.debug("Starting evaluation session...")
-        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
-
-    env.close()
-
-
-def run_stable_baselines3_a2c(
-    env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
-):
-    """
-    Run against a stable_baselines3 A2C agent.
-
-    :param env: An instance of
-        :class:`~primaite.environment.primaite_env.Primaite`.
-    :param config_values: An instance of
-        :class:`~primaite.config.training_config.TrainingConfig`.
-    param session_path: The directory path the session is writing to.
-    :param timestamp_str: The session timestamp in the format:
-        <yyyy-mm-dd>_<hh-mm-ss>.
-    """
-    if config_values.load_agent:
-        try:
-            agent = A2C.load(
-                config_values.agent_load_file,
-                env,
-                verbose=0,
-                n_steps=config_values.num_steps,
-            )
-        except Exception:
-            print(
-                "ERROR: Could not load agent at location: "
-                + config_values.agent_load_file
-            )
-            _LOGGER.error("Could not load agent")
-            _LOGGER.error("Exception occured", exc_info=True)
-    else:
-        agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
-
-    if config_values.session_type == "TRAINING":
-        # We're in a training session
-        print("Starting training session...")
-        _LOGGER.debug("Starting training session...")
-        for episode in range(config_values.num_episodes):
-            agent.learn(total_timesteps=config_values.num_steps)
-        _save_agent(agent, session_path, timestamp_str)
-    else:
-        # Default to being in an evaluation session
-        print("Starting evaluation session...")
-        _LOGGER.debug("Starting evaluation session...")
-        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
-
-    env.close()
```
```diff
-def _write_session_metadata_file(
-    session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
-):
-    """
-    Write the ``session_metadata.json`` file.
-
-    Creates a ``session_metadata.json`` in the ``session_dir`` directory
-    and adds the following key/value pairs:
-
-    - uuid: The UUID assigned to the session upon instantiation.
-    - start_datetime: The date & time the session started in iso format.
-    - end_datetime: NULL.
-    - total_episodes: NULL.
-    - total_time_steps: NULL.
-    - env:
-        - training_config:
-            - All training config items
-        - lay_down_config:
-            - All lay down config items
-
-    """
-    metadata_dict = {
-        "uuid": uuid,
-        "start_datetime": session_timestamp.isoformat(),
-        "end_datetime": None,
-        "total_episodes": None,
-        "total_time_steps": None,
-        "env": {
-            "training_config": env.training_config.to_dict(json_serializable=True),
-            "lay_down_config": env.lay_down_config,
-        },
-    }
-    filepath = session_dir / "session_metadata.json"
-    _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
-    with open(filepath, "w") as file:
-        json.dump(metadata_dict, file)
```
```diff
-def _update_session_metadata_file(session_dir: Path, env: Primaite):
-    """
-    Update the ``session_metadata.json`` file.
-
-    Updates the ``session_metadata.json`` in the ``session_dir`` directory
-    with the following key/value pairs:
-
-    - end_datetime: NULL.
-    - total_episodes: NULL.
-    - total_time_steps: NULL.
-    """
-    with open(session_dir / "session_metadata.json", "r") as file:
-        metadata_dict = json.load(file)
-
-    metadata_dict["end_datetime"] = datetime.now().isoformat()
-    metadata_dict["total_episodes"] = env.episode_count
-    metadata_dict["total_time_steps"] = env.total_step_count
-
-    filepath = session_dir / "session_metadata.json"
-    _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
-    with open(filepath, "w") as file:
-        json.dump(metadata_dict, file)
```
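The two helpers above (removed from `main.py` as the session-output handling moves into the session classes) seed `session_metadata.json` with null end-of-session fields and fill them in once the run completes. The write/update round trip, reduced to its essentials (the `env`-derived fields are dropped here for brevity; this is a sketch, not the PrimAITE implementation):

```python
import json
from datetime import datetime
from pathlib import Path


def write_metadata(session_dir: Path, uuid: str, start: datetime) -> None:
    """Seed session_metadata.json; end-of-session fields start as null."""
    metadata = {
        "uuid": uuid,
        "start_datetime": start.isoformat(),
        "end_datetime": None,
        "total_episodes": None,
        "total_time_steps": None,
    }
    (session_dir / "session_metadata.json").write_text(json.dumps(metadata))


def update_metadata(session_dir: Path, episodes: int, steps: int) -> None:
    """Fill in the fields left null by write_metadata."""
    filepath = session_dir / "session_metadata.json"
    metadata = json.loads(filepath.read_text())
    metadata["end_datetime"] = datetime.now().isoformat()
    metadata["total_episodes"] = episodes
    metadata["total_time_steps"] = steps
    filepath.write_text(json.dumps(metadata))
```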
```diff
-def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
-    """
-    Persist an agent.
-
-    Only works for stable baselines3 agents at present.
-
-    :param session_path: The directory path the session is writing to.
-    :param timestamp_str: The session timestamp in the format:
-        <yyyy-mm-dd>_<hh-mm-ss>.
-    """
-    if not isinstance(agent, OnPolicyAlgorithm):
-        msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
-        _LOGGER.error(msg)
-    else:
-        filepath = session_path / f"agent_saved_{timestamp_str}"
-        agent.save(filepath)
-        _LOGGER.debug(f"Trained agent saved as: {filepath}")
```
```diff
-def _get_session_path(session_timestamp: datetime) -> Path:
-    """
-    Get the directory path the session will output to.
-
-    This is set in the format of:
-        ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
-
-    :param session_timestamp: This is the datetime that the session started.
-    :return: The session directory path.
-    """
-    date_dir = session_timestamp.strftime("%Y-%m-%d")
-    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
-    session_path = SESSIONS_DIR / date_dir / session_dir
-    session_path.mkdir(exist_ok=True, parents=True)
-
-    return session_path
```
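`_get_session_path` above (now superseded by the session-managed output directories) builds the per-run output directory from the session start time. The same logic standalone, with a caller-supplied base directory standing in for `SESSIONS_DIR`:

```python
from datetime import datetime
from pathlib import Path


def get_session_path(base_dir: Path, session_timestamp: datetime) -> Path:
    """Return <base>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>, creating it if needed."""
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = base_dir / date_dir / session_dir
    # parents=True creates the date directory too; exist_ok makes reruns safe
    session_path.mkdir(exist_ok=True, parents=True)
    return session_path
```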
```diff
-def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
     """Run the PrimAITE Session.

     :param training_config_path: The training config filepath.
     :param lay_down_config_path: The lay down config filepath.
     """
     # Welcome message
     print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
-    uuid = str(uuid4())
-    session_timestamp: Final[datetime] = datetime.now()
-    session_dir = _get_session_path(session_timestamp)
-    timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
+    session = PrimaiteSession(training_config_path, lay_down_config_path)

-    print(f"The output directory for this session is: {session_dir}")
-
-    # Create a list of transactions
-    # A transaction is an object holding the:
-    #   - episode #
-    #   - step #
-    #   - initial observation space
-    #   - action
-    #   - reward
-    #   - new observation space
-    transaction_list = []
-
-    # Create the Primaite environment
-    env = Primaite(
-        training_config_path=training_config_path,
-        lay_down_config_path=lay_down_config_path,
-        transaction_list=transaction_list,
-        session_path=session_dir,
-        timestamp_str=timestamp_str,
-    )
-
-    print("Writing Session Metadata file...")
-
-    _write_session_metadata_file(
-        session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env
-    )
-
-    config_values = env.training_config
-
-    # Get the number of steps (which is stored in the child config file)
-    config_values.num_steps = env.episode_steps
-
-    # Run environment against an agent
-    if config_values.agent_identifier == "GENERIC":
-        run_generic(env=env, config_values=config_values)
-    elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
-        run_stable_baselines3_ppo(
-            env=env,
-            config_values=config_values,
-            session_path=session_dir,
-            timestamp_str=timestamp_str,
-        )
-    elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
-        run_stable_baselines3_a2c(
-            env=env,
-            config_values=config_values,
-            session_path=session_dir,
-            timestamp_str=timestamp_str,
-        )
-
-    print("Session finished")
-    _LOGGER.debug("Session finished")
-
-    print("Saving transaction logs...")
-    write_transaction_to_file(
-        transaction_list=transaction_list,
-        session_path=session_dir,
-        timestamp_str=timestamp_str,
-        obs_space_description=env.obs_handler.describe_structure(),
-    )
-
-    print("Updating Session Metadata file...")
-    _update_session_metadata_file(session_dir=session_dir, env=env)
-
-    print("Finished")
-    _LOGGER.debug("Finished")
+    session.setup()
+    session.learn()
+    session.evaluate()


 if __name__ == "__main__":
@@ -341,11 +32,7 @@ if __name__ == "__main__":
     parser.add_argument("--ldc")
     args = parser.parse_args()
     if not args.tc:
-        _LOGGER.error(
-            "Please provide a training config file using the --tc " "argument"
-        )
+        _LOGGER.error("Please provide a training config file using the --tc " "argument")
     if not args.ldc:
-        _LOGGER.error(
-            "Please provide a lay down config file using the --ldc " "argument"
-        )
+        _LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
     run(training_config_path=args.tc, lay_down_config_path=args.ldc)
```
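After this refactor, `run()` is a thin wrapper over the `PrimaiteSession` lifecycle: construct, then `setup()`, `learn()`, `evaluate()`. A toy stand-in that just demonstrates that call ordering (not the real `PrimaiteSession`, which loads configs and drives the agent frameworks):

```python
class ToySession:
    """Mimics the setup -> learn -> evaluate lifecycle of a session class."""

    def __init__(self):
        self.calls = []

    def setup(self):
        self.calls.append("setup")

    def learn(self):
        self.calls.append("learn")

    def evaluate(self):
        self.calls.append("evaluate")


def run(session):
    # The runner no longer knows about agents or output files; it only
    # drives the session lifecycle in order.
    session.setup()
    session.learn()
    session.evaluate()
```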
@@ -3,13 +3,7 @@
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node

@@ -44,9 +38,7 @@ class ActiveNode(Node):
        :param file_system_state: The node file system state
        :param config_values: The config values
        """
-        super().__init__(
-            node_id, name, node_type, priority, hardware_state, config_values
-        )
+        super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
        self.ip_address: str = ip_address
        # Related to Software
        self._software_state: SoftwareState = software_state

@@ -125,14 +117,10 @@ class ActiveNode(Node):
        self.file_system_state_actual = file_system_state

        if file_system_state == FileSystemState.REPAIRING:
-            self.file_system_action_count = (
-                self.config_values.file_system_repairing_limit
-            )
+            self.file_system_action_count = self.config_values.file_system_repairing_limit
            self.file_system_state_observed = FileSystemState.REPAIRING
        elif file_system_state == FileSystemState.RESTORING:
-            self.file_system_action_count = (
-                self.config_values.file_system_restoring_limit
-            )
+            self.file_system_action_count = self.config_values.file_system_restoring_limit
            self.file_system_state_observed = FileSystemState.RESTORING
        elif file_system_state == FileSystemState.GOOD:
            self.file_system_state_observed = FileSystemState.GOOD

@@ -145,9 +133,7 @@ class ActiveNode(Node):
            f"Node.file_system_state.actual:{self.file_system_state_actual}"
        )

-    def set_file_system_state_if_not_compromised(
-        self, file_system_state: FileSystemState
-    ):
+    def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
        """
        Sets the file system state (actual and observed) if not in a compromised state.

@@ -164,14 +150,10 @@ class ActiveNode(Node):
        self.file_system_state_actual = file_system_state

        if file_system_state == FileSystemState.REPAIRING:
-            self.file_system_action_count = (
-                self.config_values.file_system_repairing_limit
-            )
+            self.file_system_action_count = self.config_values.file_system_repairing_limit
            self.file_system_state_observed = FileSystemState.REPAIRING
        elif file_system_state == FileSystemState.RESTORING:
-            self.file_system_action_count = (
-                self.config_values.file_system_restoring_limit
-            )
+            self.file_system_action_count = self.config_values.file_system_restoring_limit
            self.file_system_state_observed = FileSystemState.RESTORING
        elif file_system_state == FileSystemState.GOOD:
            self.file_system_state_observed = FileSystemState.GOOD
@@ -28,9 +28,7 @@ class PassiveNode(Node):
        :param config_values: Config values.
        """
        # Pass through to Super for now
-        super().__init__(
-            node_id, name, node_type, priority, hardware_state, config_values
-        )
+        super().__init__(node_id, name, node_type, priority, hardware_state, config_values)

    @property
    def ip_address(self) -> str:
@@ -3,13 +3,7 @@
import logging
from typing import Dict, Final

-from primaite.common.enums import (
-    FileSystemState,
-    HardwareState,
-    NodeType,
-    Priority,
-    SoftwareState,
-)
+from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode

@@ -128,9 +122,7 @@ class ServiceNode(ActiveNode):
            ) or software_state != SoftwareState.COMPROMISED:
                service_value.software_state = software_state
                if software_state == SoftwareState.PATCHING:
-                    service_value.patching_count = (
-                        self.config_values.service_patching_duration
-                    )
+                    service_value.patching_count = self.config_values.service_patching_duration
        else:
            _LOGGER.info(
                f"The Nodes hardware state is OFF so the state of a service "

@@ -141,9 +133,7 @@ class ServiceNode(ActiveNode):
                f"Node.services[<key>].software_state:{software_state}"
            )

-    def set_service_state_if_not_compromised(
-        self, protocol_name: str, software_state: SoftwareState
-    ):
+    def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
        """
        Sets the software_state of a service (protocol) on the node.

@@ -159,9 +149,7 @@ class ServiceNode(ActiveNode):
            if service_value.software_state != SoftwareState.COMPROMISED:
                service_value.software_state = software_state
                if software_state == SoftwareState.PATCHING:
-                    service_value.patching_count = (
-                        self.config_values.service_patching_duration
-                    )
+                    service_value.patching_count = self.config_values.service_patching_duration
        else:
            _LOGGER.info(
                f"The Nodes hardware state is OFF so the state of a service "
@@ -4,7 +4,7 @@ import os
import subprocess
import sys

-from primaite import NOTEBOOKS_DIR, getLogger
+from primaite import getLogger, NOTEBOOKS_DIR

_LOGGER = getLogger(__name__)
@@ -86,9 +86,7 @@ def apply_iers(
            and source_node.software_state != SoftwareState.PATCHING
        ):
            if source_node.has_service(protocol):
-                if source_node.service_running(
-                    protocol
-                ) and not source_node.service_is_overwhelmed(protocol):
+                if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
                    source_valid = True
                else:
                    source_valid = False

@@ -103,10 +101,7 @@ def apply_iers(
        # 2. Check the dest node situation
        if dest_node.node_type == NodeType.SWITCH:
            # It's a switch
-            if (
-                dest_node.hardware_state == HardwareState.ON
-                and dest_node.software_state != SoftwareState.PATCHING
-            ):
+            if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
                dest_valid = True
            else:
                # IER no longer valid

@@ -116,14 +111,9 @@ def apply_iers(
                pass
        else:
            # It's not a switch or an actuator (so active node)
-            if (
-                dest_node.hardware_state == HardwareState.ON
-                and dest_node.software_state != SoftwareState.PATCHING
-            ):
+            if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
                if dest_node.has_service(protocol):
-                    if dest_node.service_running(
-                        protocol
-                    ) and not dest_node.service_is_overwhelmed(protocol):
+                    if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
                        dest_valid = True
                    else:
                        dest_valid = False

@@ -136,9 +126,7 @@ def apply_iers(
            dest_valid = False

        # 3. Check that the ACL doesn't block it
-        acl_block = acl.is_blocked(
-            source_node.ip_address, dest_node.ip_address, protocol, port
-        )
+        acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
        if acl_block:
            if _VERBOSE:
                print(

@@ -169,10 +157,7 @@ def apply_iers(

        # We might have a switch in the path, so check all nodes are operational
        for node in path_node_list:
-            if (
-                node.hardware_state != HardwareState.ON
-                or node.software_state == SoftwareState.PATCHING
-            ):
+            if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
                path_valid = False

        if path_valid:

@@ -184,9 +169,7 @@ def apply_iers(
            # Check that the link capacity is not exceeded by the new load
            while count < path_node_list_length - 1:
                # Get the link between the next two nodes
-                edge_dict = network.get_edge_data(
-                    path_node_list[count], path_node_list[count + 1]
-                )
+                edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
                link_id = edge_dict[0].get("id")
                link = links[link_id]
                # Check whether the new load exceeds the bandwidth

@@ -204,7 +187,8 @@ def apply_iers(
            while count < path_node_list_length - 1:
                # Get the link between the next two nodes
                edge_dict = network.get_edge_data(
-                    path_node_list[count], path_node_list[count + 1]
+                    path_node_list[count],
+                    path_node_list[count + 1],
                )
                link_id = edge_dict[0].get("id")
                link = links[link_id]
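The capacity check above walks consecutive node pairs along the shortest path and looks up the connecting link before adding load. The same idea can be sketched self-contained (the `Link` dataclass and `path_has_capacity` helper below are hypothetical stand-ins, not the PrimAITE classes):

```python
from dataclasses import dataclass


@dataclass
class Link:
    """Hypothetical stand-in for primaite.links.link.Link."""

    bandwidth: int  # capacity of the link
    load: int = 0   # current load on the link


def path_has_capacity(path, links, extra_load) -> bool:
    """Check that every link along a node path can absorb extra_load.

    path is an ordered list of node ids; links maps (node_a, node_b) -> Link.
    """
    # Walk consecutive node pairs, as the while-loop over path_node_list does
    for a, b in zip(path, path[1:]):
        link = links.get((a, b)) or links.get((b, a))
        if link is None or link.load + extra_load > link.bandwidth:
            return False
    return True


links = {
    ("pc", "switch"): Link(bandwidth=100, load=80),
    ("switch", "server"): Link(bandwidth=100, load=10),
}
print(path_has_capacity(["pc", "switch", "server"], links, 10))  # True
print(path_has_capacity(["pc", "switch", "server"], links, 30))  # False
```

A load of 30 is rejected because the first link would reach 110 against a bandwidth of 100.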
@@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path

from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
-from primaite.common.enums import (
-    HardwareState,
-    NodePOLInitiator,
-    NodePOLType,
-    NodeType,
-    SoftwareState,
-)
+from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed

@@ -83,10 +77,7 @@ def apply_red_agent_iers(
        if source_node.hardware_state == HardwareState.ON:
            if source_node.has_service(protocol):
                # Red agents IERs can only be valid if the source service is in a compromised state
-                if (
-                    source_node.get_service_state(protocol)
-                    == SoftwareState.COMPROMISED
-                ):
+                if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
                    source_valid = True
                else:
                    source_valid = False

@@ -124,9 +115,7 @@ def apply_red_agent_iers(
            dest_valid = False

        # 3. Check that the ACL doesn't block it
-        acl_block = acl.is_blocked(
-            source_node.ip_address, dest_node.ip_address, protocol, port
-        )
+        acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
        if acl_block:
            if _VERBOSE:
                print(

@@ -170,9 +159,7 @@ def apply_red_agent_iers(
            # Check that the link capacity is not exceeded by the new load
            while count < path_node_list_length - 1:
                # Get the link between the next two nodes
-                edge_dict = network.get_edge_data(
-                    path_node_list[count], path_node_list[count + 1]
-                )
+                edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
                link_id = edge_dict[0].get("id")
                link = links[link_id]
                # Check whether the new load exceeds the bandwidth

@@ -190,7 +177,8 @@ def apply_red_agent_iers(
            while count < path_node_list_length - 1:
                # Get the link between the next two nodes
                edge_dict = network.get_edge_data(
-                    path_node_list[count], path_node_list[count + 1]
+                    path_node_list[count],
+                    path_node_list[count + 1],
                )
                link_id = edge_dict[0].get("id")
                link = links[link_id]

@@ -248,9 +236,7 @@ def apply_red_agent_node_pol(
    state = node_instruction.get_state()
    source_node_id = node_instruction.get_source_node_id()
    source_node_service_name = node_instruction.get_source_node_service()
-    source_node_service_state_value = (
-        node_instruction.get_source_node_service_state()
-    )
+    source_node_service_state_value = node_instruction.get_source_node_service_state()

    passed_checks = False

@@ -292,9 +278,7 @@ def apply_red_agent_node_pol(
            target_node.hardware_state = state
        elif pol_type == NodePOLType.OS:
            # Change OS state
-            if isinstance(target_node, ActiveNode) or isinstance(
-                target_node, ServiceNode
-            ):
+            if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
                target_node.software_state = state
        elif pol_type == NodePOLType.SERVICE:
            # Change a service state

@@ -302,9 +286,7 @@ def apply_red_agent_node_pol(
            target_node.set_service_state(service_name, state)
        else:
            # Change the file system status
-            if isinstance(target_node, ActiveNode) or isinstance(
-                target_node, ServiceNode
-            ):
+            if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
                target_node.set_file_system_state(state)
    else:
        if _VERBOSE:
src/primaite/primaite_session.py (new file, 150 lines)
@@ -0,0 +1,150 @@
from __future__ import annotations

from pathlib import Path
from typing import Dict, Final, Union

from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig

_LOGGER = getLogger(__name__)


class PrimaiteSession:
    """
    The PrimaiteSession class.

    Provides a single learning and evaluation entry point for all training
    and lay down configurations.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
    ):
        """
        The PrimaiteSession constructor.

        :param training_config_path: The training config path.
        :param lay_down_config_path: The lay down config path.
        """
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Path] = training_config_path
        self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Path] = lay_down_config_path
        self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)

        self._agent_session: AgentSessionABC = None  # noqa
        self.session_path: Path = None  # noqa
        self.timestamp_str: str = None  # noqa
        self.learning_path: Path = None  # noqa
        self.evaluation_path: Path = None  # noqa

    def setup(self):
        """Performs the session setup."""
        if self._training_config.agent_framework == AgentFramework.CUSTOM:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
            if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.HARDCODED}")
                if self._training_config.action_type == ActionType.NODE:
                    # Deterministic Hardcoded Agent with Node Action Space
                    self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
                elif self._training_config.action_type == ActionType.ACL:
                    # Deterministic Hardcoded Agent with ACL Action Space
                    self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
                elif self._training_config.action_type == ActionType.ANY:
                    # Deterministic Hardcoded Agent with ANY Action Space
                    raise NotImplementedError
                else:
                    # Invalid AgentIdentifier ActionType combo
                    raise ValueError
            elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.DO_NOTHING}")
                if self._training_config.action_type == ActionType.NODE:
                    self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
                elif self._training_config.action_type == ActionType.ACL:
                    # DoNothing Agent with ACL Action Space
                    self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
                elif self._training_config.action_type == ActionType.ANY:
                    # DoNothing Agent with ANY Action Space
                    raise NotImplementedError
                else:
                    # Invalid AgentIdentifier ActionType combo
                    raise ValueError
            elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.RANDOM}")
                self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
            elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.DUMMY}")
                self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
            else:
                # Invalid AgentFramework AgentIdentifier combo
                raise ValueError
        elif self._training_config.agent_framework == AgentFramework.SB3:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
            # Stable Baselines3 Agent
            self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
        elif self._training_config.agent_framework == AgentFramework.RLLIB:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
            # Ray RLlib Agent
            self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
        else:
            # Invalid AgentFramework
            raise ValueError

        self.session_path: Path = self._agent_session.session_path
        self.timestamp_str: str = self._agent_session.timestamp_str
        self.learning_path: Path = self._agent_session.learning_path
        self.evaluation_path: Path = self._agent_session.evaluation_path

    def learn(self, **kwargs):
        """
        Train the agent.

        :param kwargs: Any agent-framework specific keyword args.
        """
        if self._training_config.session_type != SessionType.EVAL:
            self._agent_session.learn(**kwargs)

    def evaluate(self, **kwargs):
        """
        Evaluate the agent.

        :param kwargs: Any agent-framework specific keyword args.
        """
        if self._training_config.session_type != SessionType.TRAIN:
            self._agent_session.evaluate(**kwargs)

    def close(self):
        """Closes the agent."""
        self._agent_session.close()
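The branching in `setup()` is essentially a lookup from (`agent_framework`, `agent_identifier`) to an agent class. A minimal self-contained sketch of that dispatch (the enums are cut down and `select_agent_class` plus the returned class names are hypothetical stand-ins, not the PrimAITE API):

```python
from enum import Enum


class AgentFramework(Enum):
    CUSTOM = "CUSTOM"
    SB3 = "SB3"
    RLLIB = "RLLIB"


class AgentIdentifier(Enum):
    HARDCODED = "HARDCODED"
    RANDOM = "RANDOM"


def select_agent_class(framework, identifier):
    """Mirror the setup() if/elif chain as a dict lookup (stand-in names)."""
    table = {
        # SB3/RLlib ignore the identifier at this level
        (AgentFramework.SB3, None): "SB3Agent",
        (AgentFramework.RLLIB, None): "RLlibAgent",
        (AgentFramework.CUSTOM, AgentIdentifier.HARDCODED): "HardCodedNodeAgent",
        (AgentFramework.CUSTOM, AgentIdentifier.RANDOM): "RandomAgent",
    }
    key = (framework, identifier if framework is AgentFramework.CUSTOM else None)
    try:
        return table[key]
    except KeyError:
        # Matches setup() raising ValueError on an invalid combo
        raise ValueError(f"Invalid combo: {framework}, {identifier}")


print(select_agent_class(AgentFramework.RLLIB, AgentIdentifier.RANDOM))  # RLlibAgent
```

A table keeps the valid combinations in one place, which is the trade-off the nested if/elif chain makes against per-branch logging and `action_type` sub-dispatch.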
@@ -1,5 +1,22 @@
# The main PrimAITE application config file

# Logging
-log_level: INFO
-logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
+logging:
+  log_level: INFO
+  logger_format:
+    DEBUG: '%(asctime)s: %(message)s'
+    INFO: '%(asctime)s: %(message)s'
+    WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
+    ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
+    CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
+
+# Session
+session:
+  outputs:
+    plots:
+      size:
+        auto_size: false
+        width: 1500
+        height: 900
+      template: plotly_white
+      range_slider: false
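The new `logger_format` section maps each log level to its own format string. Python's `logging` can honour that with a small `Formatter` subclass; a sketch of the idea, not the PrimAITE implementation:

```python
import logging


class PerLevelFormatter(logging.Formatter):
    """Pick a format string based on the record's level (illustrative only)."""

    # Same shape as the logger_format mapping in the app config
    FORMATS = {
        logging.DEBUG: "%(asctime)s: %(message)s",
        logging.INFO: "%(asctime)s: %(message)s",
        logging.WARNING: "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
    }

    def format(self, record):
        # Fall back to the verbose format for ERROR/CRITICAL
        fmt = self.FORMATS.get(record.levelno, self.FORMATS[logging.WARNING])
        return logging.Formatter(fmt).format(record)


handler = logging.StreamHandler()
handler.setFormatter(PerLevelFormatter())
logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.warning("something odd")
```

INFO and DEBUG stay terse while WARNING and above carry level, logger name, and line number, matching the split in the config above.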
@@ -6,7 +6,7 @@ from pathlib import Path

import pkg_resources

-from primaite import NOTEBOOKS_DIR, getLogger
+from primaite import getLogger, NOTEBOOKS_DIR

_LOGGER = getLogger(__name__)

@@ -18,9 +18,7 @@ def run(overwrite_existing: bool = True):
    :param overwrite_existing: A bool to toggle replacing existing edited
        notebooks on or off.
    """
-    notebooks_package_data_root = pkg_resources.resource_filename(
-        "primaite", "notebooks/_package_data"
-    )
+    notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
    for subdir, dirs, files in os.walk(notebooks_package_data_root):
        for file in files:
            fp = os.path.join(subdir, file)

@@ -30,9 +28,7 @@ def run(overwrite_existing: bool = True):
            copy_file = not target_fp.is_file()

            if overwrite_existing and not copy_file:
-                copy_file = (not filecmp.cmp(fp, target_fp)) and (
-                    ".ipynb_checkpoints" not in str(target_fp)
-                )
+                copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp))

            if copy_file:
                shutil.copy2(fp, target_fp)
@@ -5,7 +5,7 @@ from pathlib import Path

import pkg_resources

-from primaite import USERS_CONFIG_DIR, getLogger
+from primaite import getLogger, USERS_CONFIG_DIR

_LOGGER = getLogger(__name__)

@@ -17,9 +17,7 @@ def run(overwrite_existing=True):
    :param overwrite_existing: A bool to toggle replacing existing edited
        config on or off.
    """
-    configs_package_data_root = pkg_resources.resource_filename(
-        "primaite", "config/_package_data"
-    )
+    configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")

    for subdir, dirs, files in os.walk(configs_package_data_root):
        for file in files:
@@ -1,5 +1,5 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
-from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger
+from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR

_LOGGER = getLogger(__name__)
@@ -1,48 +1,99 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Transaction class."""
+from datetime import datetime
+from typing import List, Tuple

+from primaite.common.enums import AgentIdentifier


class Transaction(object):
    """Transaction class."""

-    def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
+    def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
        """
-        Init.
+        Transaction constructor.

-        Args:
-            _timestamp: The time this object was created
-            _agent_identifier: An identifier for the agent in use
-            _episode_number: The episode number
-            _step_number: The step number
+        :param agent_identifier: An identifier for the agent in use
+        :param episode_number: The episode number
+        :param step_number: The step number
        """
-        self.timestamp = _timestamp
-        self.agent_identifier = _agent_identifier
-        self.episode_number = _episode_number
-        self.step_number = _step_number
+        self.timestamp = datetime.now()
+        "The datetime of the transaction"
+        self.agent_identifier: AgentIdentifier = agent_identifier
+        "The agent identifier"
+        self.episode_number: int = episode_number
+        "The episode number"
+        self.step_number: int = step_number
+        "The step number"
+        self.obs_space = None
+        "The observation space (pre)"
+        self.obs_space_pre = None
+        "The observation space before any actions are taken"
+        self.obs_space_post = None
+        "The observation space after any actions are taken"
+        self.reward = None
+        "The reward value"
+        self.action_space = None
+        "The action space invoked by the agent"
+        self.obs_space_description = None
+        "The env observation space description"

-    def set_obs_space(self, _obs_space):
+    def as_csv_data(self) -> Tuple[List, List]:
        """
-        Sets the observation space (pre).
+        Converts the Transaction to a csv data row and provides a header.

-        Args:
-            _obs_space_pre: The observation space before any actions are taken
+        :return: A tuple consisting of (header, data).
        """
-        self.obs_space = _obs_space
+        if isinstance(self.action_space, int):
+            action_length = self.action_space
+        else:
+            action_length = self.action_space.size

-    def set_reward(self, _reward):
-        """
-        Sets the reward.
+        # Create the action space headers array
+        action_header = []
+        for x in range(action_length):
+            action_header.append("AS_" + str(x))

-        Args:
-            _reward: The reward value
-        """
-        self.reward = _reward
+        # Open up a csv file
+        header = ["Timestamp", "Episode", "Step", "Reward"]
+        header = header + action_header + self.obs_space_description

-    def set_action_space(self, _action_space):
-        """
-        Sets the action space.
+        row = [
+            str(self.timestamp),
+            str(self.episode_number),
+            str(self.step_number),
+            str(self.reward),
+        ]
+        row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist()
+        return header, row

-        Args:
-            _action_space: The action space invoked by the agent
-        """
-        self.action_space = _action_space

+def _turn_action_space_to_array(action_space) -> List[str]:
+    """
+    Turns action space into a string array so it can be saved to csv.
+
+    :param action_space: The action space
+    :return: The action space as an array of strings
+    """
+    if isinstance(action_space, list):
+        return [str(i) for i in action_space]
+    else:
+        return [str(action_space)]
+
+
+def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
+    """
+    Turns observation space into a string array so it can be saved to csv.
+
+    :param obs_space: The observation space
+    :param obs_assets: The number of assets (i.e. nodes or links) in the
+        observation space
+    :param obs_features: The number of features associated with the asset
+    :return: The observation space as an array of strings
+    """
+    return_array = []
+    for x in range(obs_assets):
+        for y in range(obs_features):
+            return_array.append(str(obs_space[x][y]))
+
+    return return_array
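The `as_csv_data` method above builds a header of fixed columns plus `AS_0..AS_{n-1}` action columns, then a row of stringified values. That assembly can be shown in isolation (the two helpers below are simplified stand-ins written for this sketch, not the Transaction API):

```python
def action_space_headers(n):
    """Headers AS_0..AS_{n-1}, in the style as_csv_data builds them."""
    return [f"AS_{i}" for i in range(n)]


def transaction_row(timestamp, episode, step, reward, action_space):
    """Assemble one csv row: fixed columns, then the stringified action space."""
    if isinstance(action_space, list):
        action = [str(a) for a in action_space]  # as _turn_action_space_to_array does
    else:
        action = [str(action_space)]
    return [str(timestamp), str(episode), str(step), str(reward)] + action


header = ["Timestamp", "Episode", "Step", "Reward"] + action_space_headers(2)
row = transaction_row("2023-01-01", 0, 3, -5.0, [1, 4])
print(header)  # ['Timestamp', 'Episode', 'Step', 'Reward', 'AS_0', 'AS_1']
print(row)     # ['2023-01-01', '0', '3', '-5.0', '1', '4']
```

Because every cell is stringified, one `csv.writer.writerow` call per transaction is enough, which is what lets transactions be written in real time rather than buffered until the end of the run.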
@@ -1,91 +0,0 @@
-# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
-"""Writes the Transaction log list out to file for evaluation to utilise."""
-
-import csv
-from pathlib import Path
-
-from primaite import getLogger
-
-_LOGGER = getLogger(__name__)
-
-
-def turn_action_space_to_array(_action_space):
-    """
-    Turns action space into a string array so it can be saved to csv.
-
-    Args:
-        _action_space: The action space.
-    """
-    if isinstance(_action_space, list):
-        return [str(i) for i in _action_space]
-    else:
-        return [str(_action_space)]
-
-
-def write_transaction_to_file(
-    transaction_list,
-    session_path: Path,
-    timestamp_str: str,
-    obs_space_description: list,
-):
-    """
-    Writes transaction logs to file to support training evaluation.
-
-    :param transaction_list: The list of transactions from all steps and all
-        episodes.
-    :param session_path: The directory path the session is writing to.
-    :param timestamp_str: The session timestamp in the format:
-        <yyyy-mm-dd>_<hh-mm-ss>.
-    """
-    # Get the first transaction and use it to determine the makeup of the
-    # observation space and action space
-    # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
-    # space as "AS_1"
-    # This will be tied into the PrimAITE Use Case so that they make sense
-    template_transation = transaction_list[0]
-    action_length = template_transation.action_space.size
-    # obs_shape = template_transation.obs_space_post.shape
-    # obs_assets = template_transation.obs_space_post.shape[0]
-    # if len(obs_shape) == 1:
-    #     bit of a workaround but I think the way transactions are written will change soon
-    #     obs_features = 1
-    # else:
-    #     obs_features = template_transation.obs_space_post.shape[1]
-
-    # Create the action space headers array
-    action_header = []
-    for x in range(action_length):
-        action_header.append("AS_" + str(x))
-
-    # Create the observation space headers array
-    # obs_header_initial = [f"pre_{o}" for o in obs_space_description]
-    # obs_header_new = [f"post_{o}" for o in obs_space_description]
-
-    # Open up a csv file
-    header = ["Timestamp", "Episode", "Step", "Reward"]
-    header = header + action_header + obs_space_description
-
-    try:
-        filename = session_path / f"all_transactions_{timestamp_str}.csv"
-        _LOGGER.debug(f"Saving transaction logs: (unknown)")
-        csv_file = open(filename, "w", encoding="UTF8", newline="")
-        csv_writer = csv.writer(csv_file)
-        csv_writer.writerow(header)
-
-        for transaction in transaction_list:
-            csv_data = [
-                str(transaction.timestamp),
-                str(transaction.episode_number),
-                str(transaction.step_number),
-                str(transaction.reward),
-            ]
-            csv_data = (
-                csv_data
-                + turn_action_space_to_array(transaction.action_space)
-                + transaction.obs_space.tolist()
-            )
-            csv_writer.writerow(csv_data)
-
-        csv_file.close()
-    except Exception:
-        _LOGGER.error("Could not save the transaction file", exc_info=True)
src/primaite/utils/session_output_reader.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from pathlib import Path
from typing import Dict, Union

# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
import polars as pl


def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
    """
    Read an average rewards per episode csv file and return as a dict.

    The dictionary keys are the episode number, and the values are the mean
    reward that episode.

    :param av_rewards_csv_file: The average rewards per episode csv file path.
    :return: The average rewards per episode csv as a dict.
    """
    d = pl.read_csv(av_rewards_csv_file).to_dict()
    return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])}
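For reference, the same Episode-to-Average-Reward mapping can be built with only the stdlib `csv` module; this is a sketch of the equivalent logic (the shipped reader uses polars), and `av_rewards_dict_stdlib` is a name invented here:

```python
import csv
import tempfile
from pathlib import Path
from typing import Dict


def av_rewards_dict_stdlib(csv_file) -> Dict[int, float]:
    """Stdlib version of av_rewards_dict: maps Episode -> Average Reward."""
    with open(csv_file, newline="") as f:
        return {int(r["Episode"]): float(r["Average Reward"]) for r in csv.DictReader(f)}


# Tiny demo file in the same two-column format the session writes
demo = Path(tempfile.mkdtemp()) / "average_reward_per_episode.csv"
demo.write_text("Episode,Average Reward\n0,-10.5\n1,-3.0\n")
result = av_rewards_dict_stdlib(demo)
print(result)  # {0: -10.5, 1: -3.0}
```

The dict form is what `primaite.data_viz.session_plots` needs to plot average reward per episode.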
src/primaite/utils/session_output_writer.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union

from primaite import getLogger
from primaite.transactions.transaction import Transaction

if TYPE_CHECKING:
    from primaite.environment.primaite_env import Primaite

_LOGGER: Logger = getLogger(__name__)


class SessionOutputWriter:
    """
    A session output writer class.

    Is used to write session outputs to csv file.
    """

    _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
        "Episode",
        "Average Reward",
    ]

    def __init__(
        self,
        env: "Primaite",
        transaction_writer: bool = False,
        learning_session: bool = True,
    ):
        self._env = env
        self.transaction_writer = transaction_writer
        self.learning_session = learning_session

        if self.transaction_writer:
            fn = f"all_transactions_{self._env.timestamp_str}.csv"
        else:
            fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"

        if self.learning_session:
            self._csv_file_path = self._env.session_path / "learning" / fn
        else:
            self._csv_file_path = self._env.session_path / "evaluation" / fn

        self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)

        self._csv_file = None
        self._csv_writer = None

        self._first_write: bool = True

    def _init_csv_writer(self):
        self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
        self._csv_writer = csv.writer(self._csv_file)

    def __del__(self):
        self.close()

    def close(self):
        """Close the csv file."""
        if self._csv_file:
            self._csv_file.close()
            _LOGGER.debug(f"Finished writing file: {self._csv_file_path}")

    def write(self, data: Union[Tuple, Transaction]):
        """
        Write a row of session data.

        :param data: The row of data to write. Can be a Tuple or an instance
            of Transaction.
        """
        if isinstance(data, Transaction):
            header, data = data.as_csv_data()
        else:
            header = self._AV_REWARD_PER_EPISODE_HEADER

        if self._first_write:
            self._init_csv_writer()
            self._csv_writer.writerow(header)
            self._first_write = False
        self._csv_writer.writerow(data)
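`SessionOutputWriter` opens the file and writes the header lazily on the first `write`, so rows land on disk as they arrive and nothing is lost if the agent crashes mid-run. The pattern in isolation, with a hypothetical `LazyCsvWriter` stand-in rather than the PrimAITE class:

```python
import csv
import tempfile
from pathlib import Path


class LazyCsvWriter:
    """Open the file and emit the header only on the first write (a sketch
    of the SessionOutputWriter pattern, not the PrimAITE class)."""

    def __init__(self, path):
        self._path = Path(path)
        self._file = None
        self._writer = None

    def write(self, header, row):
        if self._writer is None:
            # Create the output directory tree and file on first use
            self._path.parent.mkdir(parents=True, exist_ok=True)
            self._file = open(self._path, "w", encoding="UTF8", newline="")
            self._writer = csv.writer(self._file)
            self._writer.writerow(header)
        self._writer.writerow(row)

    def close(self):
        if self._file:
            self._file.close()


path = Path(tempfile.mkdtemp()) / "learning" / "rewards.csv"
w = LazyCsvWriter(path)
w.write(["Episode", "Average Reward"], [0, -10.5])
w.write(["Episode", "Average Reward"], [1, -3.0])
w.close()
print(path.read_text())
```

Deferring the `open` also means a session that never produces a row never leaves an empty file with only a header behind.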
@@ -1,11 +1,20 @@
# Main Config File

-# Generic config values
-# Choose one of these (dependent on Agent being trained)
-# "STABLE_BASELINES3_PPO"
-# "STABLE_BASELINES3_A2C"
-# "GENERIC"
-agent_identifier: STABLE_BASELINES3_A2C
+# Sets which agent algorithm framework will be used:
+# "SB3" (Stable Baselines3)
+# "RLLIB" (Ray[RLlib])
+# "NONE" (Custom Agent)
+agent_framework: SB3
+
+# Sets which Red Agent algo/class will be used:
+# "PPO" (Proximal Policy Optimization)
+# "A2C" (Advantage Actor Critic)
+# "HARDCODED" (Custom Agent)
+# "RANDOM" (Random Action)
+agent_identifier: PPO

# Sets How the Action Space is defined:
# "NODE"
# "ACL"

@@ -18,7 +27,7 @@ num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
-session_type: TRAINING
+session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
@@ -1,11 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: A2C

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -28,7 +39,7 @@ observation_space:
time_delay: 1

# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -24,7 +35,7 @@ observation_space:
time_delay: 1
# Filename of the scenario / laydown

session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -25,7 +36,7 @@ observation_space:
time_delay: 1

# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,10 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +30,7 @@ num_steps: 15
time_delay: 1

# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

tests/config/test_random_red_main_config.yaml (new file, 112 lines)
@@ -0,0 +1,112 @@
# Training Config File

# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM

# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY

# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True

# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 2
# Number of time_steps per episode
num_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1

# Type of session to be run (TRAINING or EVALUATION)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]

# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000

# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10

# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system
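The comments in the config above pair each `agent_framework` with a set of valid `agent_identifier` values (PPO/A2C for the RL frameworks, the hardcoded/random/dummy agents for CUSTOM). A small sketch of validating that pairing; the allowed combinations are inferred from the config comments, not from PrimAITE's actual validation code:

```python
# Valid agent_identifier values per agent_framework, inferred from the
# comments in the training config above (illustrative only).
VALID_IDENTIFIERS = {
    "SB3": {"PPO", "A2C"},
    "RLLIB": {"PPO", "A2C"},
    "CUSTOM": {"HARDCODED", "DO_NOTHING", "RANDOM", "DUMMY"},
}


def check_agent_config(agent_framework: str, agent_identifier: str) -> bool:
    """Return True if the identifier makes sense for the chosen framework."""
    return agent_identifier in VALID_IDENTIFIERS.get(agent_framework, set())


assert check_agent_config("CUSTOM", "DUMMY")    # this file's settings
assert check_agent_config("SB3", "PPO")
assert not check_agent_config("SB3", "RANDOM")  # RL frameworks take PPO/A2C only
```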
@@ -1,43 +1,150 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import datetime
import shutil
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Union
from typing import Dict, Union
from unittest.mock import patch

import pytest

from primaite import getLogger
from primaite.common.enums import AgentIdentifier
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path

ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1

_LOGGER = getLogger(__name__)

def _get_temp_session_path(session_timestamp: datetime) -> Path:

class TempPrimaiteSession(PrimaiteSession):
    """
    A temporary PrimaiteSession class.

    Uses context manager for deletion of files upon exit.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
    ):
        super().__init__(training_config_path, lay_down_config_path)
        self.setup()

    def learn_av_reward_per_episode(self) -> Dict[int, float]:
        """Get the learn av reward per episode from file."""
        csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
        return av_rewards_dict(self.learning_path / csv_file)

    def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
        """Get the eval av reward per episode from file."""
        csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
        return av_rewards_dict(self.evaluation_path / csv_file)

    @property
    def env(self) -> Primaite:
        """Direct access to the env for ease of testing."""
        return self._agent_session._env  # noqa

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        shutil.rmtree(self.session_path)
        shutil.rmtree(self.session_path.parent)
        _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")


@pytest.fixture
def temp_primaite_session(request):
    """
    Provides a temporary PrimaiteSession instance.

    It's temporary as it uses a temporary directory as the session path.

    To use this fixture you need to:

    - parametrize your test function with:

      - "temp_primaite_session"
      - [[path to training config, path to lay down config]]

    - Include the temp_primaite_session fixture as a param in your test
      function.
    - use the temp_primaite_session as a context manager assigning it the
      name 'session'.

    .. code:: python

        from primaite.config.lay_down_config import dos_very_basic_config_path
        from primaite.config.training_config import main_training_config_path

        @pytest.mark.parametrize(
            "temp_primaite_session",
            [
                [main_training_config_path(), dos_very_basic_config_path()]
            ],
            indirect=True
        )
        def test_primaite_session(temp_primaite_session):
            with temp_primaite_session as session:
                # Learning outputs are saved in session.learning_path
                session.learn()

                # Evaluation outputs are saved in session.evaluation_path
                session.evaluate()

                # To ensure that all files are written, you must call .close()
                session.close()

                # If you need to inspect any session outputs, it must be done
                # inside the context manager

            # Now that we've exited the context manager, the
            # session.session_path directory and its contents are deleted
    """
    training_config_path = request.param[0]
    lay_down_config_path = request.param[1]
    with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
        mck.session_timestamp = datetime.now()

        return TempPrimaiteSession(training_config_path, lay_down_config_path)


@pytest.fixture
def temp_session_path() -> Path:
    """
    Get a temp directory session path the test session will output to.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    session_timestamp = datetime.now()
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir
    session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
    session_path.mkdir(exist_ok=True, parents=True)

    return session_path


def _get_primaite_env_from_config(
    training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
    training_config_path: Union[str, Path],
    lay_down_config_path: Union[str, Path],
    temp_session_path,
):
    """Takes a config path and returns the created instance of Primaite."""
    session_timestamp: datetime = datetime.now()
    session_path = _get_temp_session_path(session_timestamp)
    session_path = temp_session_path(session_timestamp)

    timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    env = Primaite(
        training_config_path=training_config_path,
        lay_down_config_path=lay_down_config_path,
        transaction_list=[],
        session_path=session_path,
        timestamp_str=timestamp_str,
    )
@@ -46,7 +153,7 @@ def _get_primaite_env_from_config(

    # TODO: This needs to be refactored to happen outside. Should be part of
    # a main Session class.
    if env.training_config.agent_identifier == "GENERIC":
    if env.training_config.agent_identifier is AgentIdentifier.RANDOM:
        run_generic(env, config_values)

    return env

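The fixture above relies on `unittest.mock.patch` to swap the real session-path provider for `get_temp_session_path` while the session is constructed. A minimal stand-alone illustration of that mechanism, using a hypothetical `PathProvider` class rather than PrimAITE's actual module layout:

```python
import tempfile
from pathlib import Path
from unittest.mock import patch


class PathProvider:
    """Hypothetical stand-in for the code that decides where sessions go."""

    def get_session_path(self) -> Path:
        return Path("/opt/primaite/sessions")


provider = PathProvider()

# Inside the patch, callers get a temp directory instead of the real path.
with patch.object(PathProvider, "get_session_path", return_value=Path(tempfile.gettempdir())):
    patched = provider.get_session_path()

assert patched == Path(tempfile.gettempdir())
# Once the patch exits, the original behaviour is restored.
assert provider.get_session_path() == Path("/opt/primaite/sessions")
```

This is why test sessions can be redirected to a temp directory without any changes to the production code path.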
@@ -1,8 +0,0 @@
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run


def test_primaite_main_e2e():
    """Tests the primaite.main.run function end-to-end."""
    run(main_training_config_path(), data_manipulation_config_path())
tests/mock_and_patch/get_session_path_mock.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import tempfile
from datetime import datetime
from pathlib import Path

from primaite import getLogger

_LOGGER = getLogger(__name__)


def get_temp_session_path(session_timestamp: datetime) -> Path:
    """
    Get a temp directory session path the test session will output to.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
    session_path.mkdir(exist_ok=True, parents=True)
    _LOGGER.debug(f"Created temp session directory: {session_path}")
    return session_path
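The helper writes session outputs under a `<date>/<date_time>` hierarchy beneath the system temp directory. A self-contained version of the same layout, with a fixed timestamp so the resulting paths are predictable (`temp_session_path` here is a local re-implementation for illustration, not the PrimAITE function):

```python
import tempfile
from datetime import datetime
from pathlib import Path


def temp_session_path(session_timestamp: datetime) -> Path:
    """Build <tmp>/primaite/<YYYY-MM-DD>/<YYYY-MM-DD_HH-MM-SS> and create it."""
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir
    path.mkdir(exist_ok=True, parents=True)
    return path


ts = datetime(2023, 1, 31, 9, 30, 0)
p = temp_session_path(ts)
assert p.is_dir()
assert p.name == "2023-01-31_09-30-00"
assert p.parent.name == "2023-01-31"
```

Grouping sessions by date, then by full timestamp, keeps repeated runs from colliding while remaining easy to browse.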
@@ -95,8 +95,6 @@ def test_rule_hash():
    rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
    hash_value_local = hash(rule)

    hash_value_remote = acl.get_dictionary_hash(
        "DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
    )
    hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")

    assert hash_value_local == hash_value_remote

@@ -2,84 +2,79 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from primaite.environment.observations import (
|
||||
NodeLinkTable,
|
||||
NodeStatuses,
|
||||
ObservationsHandler,
|
||||
)
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_primaite_env_from_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env(request):
|
||||
"""Build Primaite environment for integration tests of observation space."""
|
||||
marker = request.node.get_closest_marker("env_config_paths")
|
||||
training_config_path = marker.args[0]["training_config_path"]
|
||||
lay_down_config_path = marker.args[0]["lay_down_config_path"]
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path,
|
||||
)
|
||||
yield env
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_default_obs_space(env: Primaite):
|
||||
def test_default_obs_space(temp_primaite_session):
|
||||
"""Create environment with no obs space defined in config and check that the default obs space was created."""
|
||||
env.update_environent_obs()
|
||||
with temp_primaite_session as session:
|
||||
session.env.update_environent_obs()
|
||||
|
||||
components = env.obs_handler.registered_obs_components
|
||||
components = session.env.obs_handler.registered_obs_components
|
||||
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], NodeLinkTable)
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], NodeLinkTable)
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_registering_components(env: Primaite):
|
||||
def test_registering_components(temp_primaite_session):
|
||||
"""Test regitering and deregistering a component."""
|
||||
handler = ObservationsHandler()
|
||||
component = NodeStatuses(env)
|
||||
handler.register(component)
|
||||
assert component in handler.registered_obs_components
|
||||
handler.deregister(component)
|
||||
assert component not in handler.registered_obs_components
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
handler = ObservationsHandler()
|
||||
component = NodeStatuses(env)
|
||||
handler.register(component)
|
||||
assert component in handler.registered_obs_components
|
||||
handler.deregister(component)
|
||||
assert component not in handler.registered_obs_components
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_LINK_TABLE.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestNodeLinkTable:
|
||||
"""Test the NodeLinkTable observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, env: Primaite):
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with box observation space."""
|
||||
env.update_environent_obs()
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have three nodes and two links, with two service
|
||||
# therefore the box observation space will have:
|
||||
# * 5 rows (3 nodes + 2 links)
|
||||
# * 6 columns (four fixed and two for the services)
|
||||
assert env.env_obs.shape == (5, 6)
|
||||
# we have three nodes and two links, with two service
|
||||
# therefore the box observation space will have:
|
||||
# * 5 rows (3 nodes + 2 links)
|
||||
# * 6 columns (four fixed and two for the services)
|
||||
assert env.env_obs.shape == (5, 6)
|
||||
|
||||
def test_value(self, env: Primaite):
|
||||
def test_value(self, temp_primaite_session):
|
||||
"""Test that the observation is generated correctly.
|
||||
|
||||
The laydown has:
|
||||
@@ -125,36 +120,43 @@ class TestNodeLinkTable:
|
||||
* 999 (999 traffic service1)
|
||||
* 0 (no traffic for service2)
|
||||
"""
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
[
|
||||
[1, 1, 3, 1, 1, 1],
|
||||
[2, 1, 1, 1, 1, 4],
|
||||
[3, 1, 1, 1, 0, 0],
|
||||
[4, 0, 0, 0, 999, 0],
|
||||
[5, 0, 0, 0, 999, 0],
|
||||
],
|
||||
)
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
[
|
||||
[1, 1, 3, 1, 1, 1],
|
||||
[2, 1, 1, 1, 1, 4],
|
||||
[3, 1, 1, 1, 0, 0],
|
||||
[4, 0, 0, 0, 999, 0],
|
||||
[5, 0, 0, 0, 999, 0],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestNodeStatuses:
|
||||
"""Test the NodeStatuses observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, env: Primaite):
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with NodeStatuses as the only component."""
|
||||
assert env.env_obs.shape == (15,)
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
assert env.env_obs.shape == (15,)
|
||||
|
||||
def test_values(self, env: Primaite):
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that the hardware and software states are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
@@ -181,28 +183,36 @@ class TestNodeStatuses:
|
||||
* service 1 = n/a (0)
|
||||
* service 2 = n/a (0)
|
||||
"""
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
|


-@pytest.mark.env_config_paths(
-    dict(
-        training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
-        lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
-    )
-)
+@pytest.mark.parametrize(
+    "temp_primaite_session",
+    [
+        [
+            TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
+            TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
+        ]
+    ],
+    indirect=True,
+)
 class TestLinkTrafficLevels:
     """Test the LinkTrafficLevels observation component (in isolation)."""

-    def test_obs_shape(self, env: Primaite):
+    def test_obs_shape(self, temp_primaite_session):
         """Try creating env with MultiDiscrete observation space."""
-        env.update_environent_obs()
+        with temp_primaite_session as session:
+            env = session.env
+            env.update_environent_obs()

-        # we have two links and two services, so the shape should be 2 * 2
-        assert env.env_obs.shape == (2 * 2,)
+            # we have two links and two services, so the shape should be 2 * 2
+            assert env.env_obs.shape == (2 * 2,)

-    def test_values(self, env: Primaite):
+    def test_values(self, temp_primaite_session):
         """Test that traffic values are encoded correctly.

         The laydown has:
@@ -212,12 +222,14 @@ class TestLinkTrafficLevels:
         * an IER trying to send 999 bits of data over both links the whole time (via the first service)
         * link bandwidth of 1000, therefore the utilisation is 99.9%
         """
-        obs, reward, done, info = env.step(0)
-        obs, reward, done, info = env.step(0)
+        with temp_primaite_session as session:
+            env = session.env
+            obs, reward, done, info = env.step(0)
+            obs, reward, done, info = env.step(0)

-        # the observation space has combine_service_traffic set to False, so the space has this format:
-        # [link1_service1, link1_service2, link2_service1, link2_service2]
-        # we send 999 bits of data via link1 and link2 on service 1.
-        # therefore the first and third elements should be 6 and all others 0
-        # (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
-        assert np.array_equal(obs, [6, 0, 6, 0])
+            # the observation space has combine_service_traffic set to False, so the space has this format:
+            # [link1_service1, link1_service2, link2_service1, link2_service2]
+            # we send 999 bits of data via link1 and link2 on service 1.
+            # therefore the first and third elements should be 6 and all others 0
+            # (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
+            assert np.array_equal(obs, [6, 0, 6, 0])
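The expected `[6, 0, 6, 0]` follows from the 8-level encoding described in the test comments. A hedged sketch of that banding (the function name and exact thresholds are assumptions reconstructed from the comments, not PrimAITE's implementation):

```python
def encode_traffic_level(load_bits: int, bandwidth_bits: int) -> int:
    """Map link utilisation onto discrete levels 0-7.

    Assumed banding, per the test comments: 7 means exactly 100%
    utilisation, and lower levels cover 12.5% bands, so 6 covers
    87.5%-100%.
    """
    utilisation = load_bits / bandwidth_bits
    if utilisation >= 1.0:
        return 7
    # floor into 12.5% bands, capped at 6 for anything short of 100%
    return min(6, int(utilisation / 0.125))

# 999 bits over a 1000-bit link is 99.9% utilisation, which lands in
# the 87.5%-100% band, i.e. level 6 rather than level 7.
```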
55 tests/test_primaite_session.py Normal file
@@ -0,0 +1,55 @@
+import os
+
+import pytest
+
+from primaite import getLogger
+from primaite.config.lay_down_config import dos_very_basic_config_path
+from primaite.config.training_config import main_training_config_path
+
+_LOGGER = getLogger(__name__)
+
+
+@pytest.mark.parametrize(
+    "temp_primaite_session",
+    [[main_training_config_path(), dos_very_basic_config_path()]],
+    indirect=True,
+)
+def test_primaite_session(temp_primaite_session):
+    """Tests the PrimaiteSession class and its outputs."""
+    with temp_primaite_session as session:
+        session_path = session.session_path
+        assert session_path.exists()
+
+        session.learn()
+        # Learning outputs are saved in session.learning_path
+
+        session.evaluate()
+        # Evaluation outputs are saved in session.evaluation_path
+
+        # If you need to inspect any session outputs, it must be done inside
+        # the context manager
+
+        # Check that the metadata json file exists
+        assert (session_path / "session_metadata.json").exists()
+
+        # Check that the network png file exists
+        assert (session_path / f"network_{session.timestamp_str}.png").exists()
+
+        # Check that both the transactions and av reward csv files exist
+        for file in session.learning_path.iterdir():
+            if file.suffix == ".csv":
+                assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
+
+        # Check that both the transactions and av reward csv files exist
+        for file in session.evaluation_path.iterdir():
+            if file.suffix == ".csv":
+                assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
+
+        _LOGGER.debug("Inspecting files in temp session path...")
+        for dir_path, dir_names, file_names in os.walk(session_path):
+            for file in file_names:
+                path = os.path.join(dir_path, file)
+                file_str = path.split(str(session_path))[-1]
+                _LOGGER.debug(f" {file_str}")
+
+    # Now that we've exited the context manager, the session.session_path
+    # directory and its contents are deleted
+    assert not session_path.exists()
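The final `assert not session_path.exists()` works because the temp session is a context manager that deletes its output directory on exit. A toy sketch of that pattern (class and attribute names are illustrative, not the real `TempPrimaiteSession`):

```python
import tempfile
from pathlib import Path


class ToyTempSession:
    """Illustrative stand-in: creates a temp session dir, removes it on exit."""

    def __enter__(self):
        self._tmp = tempfile.TemporaryDirectory()
        self.session_path = Path(self._tmp.name)
        return self

    def __exit__(self, exc_type, exc, tb):
        self._tmp.cleanup()  # session outputs vanish here
        return False


with ToyTempSession() as session:
    existed_inside = session.session_path.exists()
existed_after = session.session_path.exists()
```

This is why any inspection of session outputs has to happen inside the `with` block.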
@@ -1,68 +1,30 @@
-from datetime import datetime
 import pytest

 from primaite.config.lay_down_config import data_manipulation_config_path
-from primaite.environment.primaite_env import Primaite
 from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
 from tests import TEST_CONFIG_ROOT
-from tests.conftest import _get_temp_session_path


-def run_generic(env, config_values):
-    """Run against a generic agent."""
-    # Reset the environment at the start of the episode
-    env.reset()
-    for episode in range(0, config_values.num_episodes):
-        for step in range(0, config_values.num_steps):
-            # Send the observation space to the agent to get an action
-            # TEMP - random action for now
-            # action = env.blue_agent_action(obs)
-            # action = env.action_space.sample()
-            action = 0
-
-            # Run the simulation step on the live environment
-            obs, reward, done, info = env.step(action)
-
-            # Break if done is True
-            if done:
-                break
-
-        # Reset the environment at the end of the episode
-        env.reset()
-
-    env.close()
-
-
-def test_random_red_agent_behaviour():
-    """
-    Test that hardware state is penalised at each step.
-
-    When the initial state is OFF compared to reference state which is ON.
-    """
+@pytest.mark.parametrize(
+    "temp_primaite_session",
+    [
+        [
+            TEST_CONFIG_ROOT / "test_random_red_main_config.yaml",
+            data_manipulation_config_path(),
+        ]
+    ],
+    indirect=True,
+)
+def test_random_red_agent_behaviour(temp_primaite_session):
+    """Test that red agent POL is randomised each episode."""
     list_of_node_instructions = []

     # RUN TWICE so we can make sure that red agent is randomised
-    for i in range(2):
-        """Takes a config path and returns the created instance of Primaite."""
-        session_timestamp: datetime = datetime.now()
-        session_path = _get_temp_session_path(session_timestamp)
+    with temp_primaite_session as session:
+        session.evaluate()
+        list_of_node_instructions.append(session.env.red_node_pol)

-        timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
-        env = Primaite(
-            training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
-            lay_down_config_path=data_manipulation_config_path(),
-            transaction_list=[],
-            session_path=session_path,
-            timestamp_str=timestamp_str,
-        )
-        # set red_agent_
-        env.training_config.random_red_agent = True
-        training_config = env.training_config
-        training_config.num_steps = env.episode_steps
-
-        run_generic(env, training_config)
-        # add red pol instructions to list
-        list_of_node_instructions.append(env.red_node_pol)
+        session.evaluate()
+        list_of_node_instructions.append(session.env.red_node_pol)

     # compare instructions to make sure that red instructions are truly random
     for index, instruction in enumerate(list_of_node_instructions):
@@ -73,5 +35,4 @@ def test_random_red_agent_behaviour():
         print(f"{key} end step: {instruction.get_end_step()}")
         print(f"{key} target node id: {instruction.get_target_node_id()}")
         print("")

     assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])
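The closing assertion spells inequality as `__ne__`; `a != b` is the idiomatic form and invokes the same hook. A small illustration with a hypothetical POL-like class (names are made up for the example):

```python
class ToyNodeInstruction:
    """Hypothetical stand-in for a red-agent node instruction."""

    def __init__(self, start_step: int):
        self.start_step = start_step

    def __eq__(self, other):
        return self.start_step == other.start_step

    def __ne__(self, other):
        return not self.__eq__(other)


first_run = ToyNodeInstruction(start_step=3)
second_run = ToyNodeInstruction(start_step=7)
# explicit dunder call, equivalent to `first_run != second_run`
randomised = first_run.__ne__(second_run)
```

The test therefore only passes when the two evaluation runs produced instructions that compare unequal, i.e. the red POL really was re-randomised.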
@@ -1,13 +1,7 @@
 """Used to test Active Node functions."""
 import pytest

-from primaite.common.enums import (
-    FileSystemState,
-    HardwareState,
-    NodeType,
-    Priority,
-    SoftwareState,
-)
+from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
 from primaite.common.service import Service
 from primaite.config.training_config import TrainingConfig
 from primaite.nodes.active_node import ActiveNode
@@ -57,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
         file_system_state="GOOD",
         config_values=1,
     )
-    service_attributes = Service(
-        name="node", port="80", software_state=SoftwareState.COMPROMISED
-    )
+    service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
     service_node.add_service(service_attributes)

     for x in range(5):
@@ -1,21 +1,26 @@
+import pytest
+
 from tests import TEST_CONFIG_ROOT
-from tests.conftest import _get_primaite_env_from_config


-def test_rewards_are_being_penalised_at_each_step_function():
+@pytest.mark.parametrize(
+    "temp_primaite_session",
+    [
+        [
+            TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
+            TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
+        ]
+    ],
+    indirect=True,
+)
+def test_rewards_are_being_penalised_at_each_step_function(
+    temp_primaite_session,
+):
     """
     Test that hardware state is penalised at each step.

     When the initial state is OFF compared to reference state which is ON.
     """
-    env = _get_primaite_env_from_config(
-        training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
-        lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
-    )
-
     """
     The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps:
     On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute.
     For example, turning the nodes' file system state to CORRUPT from its original state GOOD.
@@ -38,4 +43,8 @@ def test_rewards_are_being_penalised_at_each_step_function():
     For the 4 steps where this occurs the average reward is:
     Average Reward: -8 (-120 / 15)
     """
-    assert env.average_reward == -8.0
+    with temp_primaite_session as session:
+        session.evaluate()
+        session.close()
+        ev_rewards = session.eval_av_reward_per_episode_csv()
+        assert ev_rewards[1] == -8.0
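The expected `-8.0` comes straight from the arithmetic in the docstring: the penalties over the episode total -120 across 15 steps. Worked out in plain Python (the -120 total is taken from the docstring, not recomputed from the laydown config):

```python
num_steps = 15       # steps in one_node_states_on_off_lay_down_config.yaml
total_penalty = -120  # total reward over the episode, per the docstring
average_reward = total_penalty / num_steps
```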
@@ -1,9 +1,10 @@
 import time

+import pytest
+
 from primaite.common.enums import HardwareState
 from primaite.environment.primaite_env import Primaite
 from tests import TEST_CONFIG_ROOT
-from tests.conftest import _get_primaite_env_from_config


 def run_generic_set_actions(env: Primaite):
@@ -17,7 +18,6 @@ def run_generic_set_actions(env: Primaite):
     # TEMP - random action for now
     # action = env.blue_agent_action(obs)
     action = 0
-    print("Episode:", episode, "\nStep:", step)
     if step == 5:
         # [1, 1, 2, 1, 1, 1]
         # Creates an ACL rule
@@ -44,59 +44,71 @@ def run_generic_set_actions(env: Primaite):
     # env.close()


-def test_single_action_space_is_valid():
-    """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
-    env = _get_primaite_env_from_config(
-        training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
-        lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
-    )
-
-    run_generic_set_actions(env)
-
-    # Retrieve the action space dictionary values from environment
-    env_action_space_dict = env.action_dict.values()
-    # Flags to check the conditions of the action space
-    contains_acl_actions = False
-    contains_node_actions = False
-    both_action_spaces = False
-    # Loop through each element of the list (which is every value from the dictionary)
-    for dict_item in env_action_space_dict:
-        # Node action detected
-        if len(dict_item) == 4:
-            contains_node_actions = True
-        # Link action detected
-        elif len(dict_item) == 6:
-            contains_acl_actions = True
-        # If both are there then the ANY action type is working
-        if contains_node_actions and contains_acl_actions:
-            both_action_spaces = True
-    # Check condition should be True
-    assert both_action_spaces
+@pytest.mark.parametrize(
+    "temp_primaite_session",
+    [
+        [
+            TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
+            TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
+        ]
+    ],
+    indirect=True,
+)
+def test_single_action_space_is_valid(temp_primaite_session):
+    """Test single action space is valid."""
+    with temp_primaite_session as session:
+        env = session.env

+        run_generic_set_actions(env)
+        # Retrieve the action space dictionary values from environment
+        env_action_space_dict = env.action_dict.values()
+        # Flags to check the conditions of the action space
+        contains_acl_actions = False
+        contains_node_actions = False
+        both_action_spaces = False
+        # Loop through each element of the list (which is every value from the dictionary)
+        for dict_item in env_action_space_dict:
+            # Node action detected
+            if len(dict_item) == 4:
+                contains_node_actions = True
+            # Link action detected
+            elif len(dict_item) == 6:
+                contains_acl_actions = True
+            # If both are there then the ANY action type is working
+            if contains_node_actions and contains_acl_actions:
+                both_action_spaces = True
+        # Check condition should be True
+        assert both_action_spaces


-def test_agent_is_executing_actions_from_both_spaces():
+@pytest.mark.parametrize(
+    "temp_primaite_session",
+    [
+        [
+            TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
+            TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
+        ]
+    ],
+    indirect=True,
+)
+def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
     """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
-    env = _get_primaite_env_from_config(
-        training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
-        lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
-    )
-    # Run environment with specified fixed blue agent actions only
-    run_generic_set_actions(env)
-    # Retrieve hardware state of computer_1 node in laydown config
-    # Agent turned this off in Step 5
-    computer_node_hardware_state = env.nodes["1"].hardware_state
-    # Retrieve the Access Control List object stored by the environment at the end of the episode
-    access_control_list = env.acl
-    # Use the Access Control List object acl object attribute to get dictionary
-    # Use dictionary.values() to get total list of all items in the dictionary
-    acl_rules_list = access_control_list.acl.values()
-    # Length of this list tells you how many items are in the dictionary
-    # This number is the frequency of Access Control Rules in the environment
-    # In the scenario, we specified that the agent should create only 1 acl rule
-    num_of_rules = len(acl_rules_list)
-    # Therefore these statements below MUST be true
-    assert computer_node_hardware_state == HardwareState.OFF
-    assert num_of_rules == 1
+    with temp_primaite_session as session:
+        env = session.env
+        # Run environment with specified fixed blue agent actions only
+        run_generic_set_actions(env)
+        # Retrieve hardware state of computer_1 node in laydown config
+        # Agent turned this off in Step 5
+        computer_node_hardware_state = env.nodes["1"].hardware_state
+        # Retrieve the Access Control List object stored by the environment at the end of the episode
+        access_control_list = env.acl
+        # Use the Access Control List object acl object attribute to get dictionary
+        # Use dictionary.values() to get total list of all items in the dictionary
+        acl_rules_list = access_control_list.acl.values()
+        # Length of this list tells you how many items are in the dictionary
+        # This number is the frequency of Access Control Rules in the environment
+        # In the scenario, we specified that the agent should create only 1 acl rule
+        num_of_rules = len(acl_rules_list)
+        # Therefore these statements below MUST be true
+        assert computer_node_hardware_state == HardwareState.OFF
+        assert num_of_rules == 1
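Both tests distinguish action types purely by tuple length: 4 elements marks a node action, 6 an ACL action. That check can be factored into a small helper, sketched here (the helper is illustrative and does not exist in PrimAITE):

```python
def classify_action(action) -> str:
    """Classify an action-space entry by its length.

    Assumed convention from the tests above: 4-element actions are node
    actions, 6-element actions are ACL actions.
    """
    if len(action) == 4:
        return "node"
    if len(action) == 6:
        return "acl"
    return "unknown"


# hypothetical ANY-type action dictionary mixing both kinds
action_dict = {0: (1, 1, 2, 1), 1: (1, 1, 2, 1, 1, 1)}
kinds = {classify_action(a) for a in action_dict.values()}
both_action_spaces = kinds >= {"node", "acl"}
```

Collecting the kinds into a set replaces the three boolean flags the tests thread through the loop.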
@@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT

 def test_legacy_lay_down_config_yaml_conversion():
     """Tests the conversion of legacy lay down config files."""
-    legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
-    new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
+    legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
+    new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"

     with open(legacy_path, "r") as file:
         legacy_dict = yaml.safe_load(file)
@@ -24,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion():

 def test_create_config_values_main_from_file():
     """Tests creating an instance of TrainingConfig from file."""
-    new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
+    new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"

     training_config.load(new_path)


 def test_create_config_values_main_from_legacy_file():
     """Tests creating an instance of TrainingConfig from legacy file."""
-    new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
+    new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"

     training_config.load(new_path, legacy_file=True)