diff --git a/.azure/artifact-release-pipeline.yaml b/.azure/artifact-release-pipeline.yaml index ca8f5b60..47e9aacc 100644 --- a/.azure/artifact-release-pipeline.yaml +++ b/.azure/artifact-release-pipeline.yaml @@ -1,38 +1,12 @@ trigger: - main +pool: + vmImage: ubuntu-latest strategy: matrix: - Ubuntu2004Python38: - python.version: '3.8' - imageName: 'ubuntu-20.04' - Ubuntu2004Python39: - python.version: '3.9' - imageName: 'ubuntu-20.04' - Ubuntu2004Python310: + Python310: python.version: '3.10' - imageName: 'ubuntu-20.04' - WindowsPython38: - python.version: '3.8' - imageName: 'windows-latest' - WindowsPython39: - python.version: '3.9' - imageName: 'windows-latest' - WindowsPython310: - python.version: '3.10' - imageName: 'windows-latest' - MacPython38: - python.version: '3.8' - imageName: 'macOS-latest' - MacPython39: - python.version: '3.9' - imageName: 'macOS-latest' - MacPython310: - python.version: '3.10' - imageName: 'macOS-latest' - -pool: - vmImage: $(imageName) steps: - task: UsePythonVersion@0 diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 8bfdca02..902eb38d 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -6,18 +6,29 @@ trigger: - bugfix/* - release/* -pool: - vmImage: ubuntu-latest strategy: matrix: - Python38: + UbuntuPython38: python.version: '3.8' - Python39: - python.version: '3.9' - Python310: + imageName: 'ubuntu-latest' + UbuntuPython310: python.version: '3.10' - Python311: - python.version: '3.11' + imageName: 'ubuntu-latest' + WindowsPython38: + python.version: '3.8' + imageName: 'windows-latest' + WindowsPython310: + python.version: '3.10' + imageName: 'windows-latest' + MacOSPython38: + python.version: '3.8' + imageName: 'macOS-latest' + MacOSPython310: + python.version: '3.10' + imageName: 'macOS-latest' + +pool: + vmImage: $(imageName) steps: - task: UsePythonVersion@0 @@ -47,16 +58,17 @@ steps: PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) python -m 
pip install $PRIMAITE_WHEEL[dev] displayName: 'Install PrimAITE' + condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) + +- script: | + forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]" + displayName: 'Install PrimAITE' + condition: eq( variables['Agent.OS'], 'Windows_NT' ) - script: | primaite setup displayName: 'Perform PrimAITE Setup' -#- script: | -# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# displayName: 'Lint with flake8' - - script: | pytest tests/ - displayName: 'Run unmarked tests' + displayName: 'Run tests' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a08b17b8..6e435bee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,6 +13,9 @@ repos: rev: 23.1.0 hooks: - id: black + args: [ "--line-length=120" ] + additional_dependencies: + - jupyter - repo: http://github.com/pycqa/isort rev: 5.12.0 hooks: @@ -22,4 +25,5 @@ repos: rev: 6.0.0 hooks: - id: flake8 - additional_dependencies: [ flake8-docstrings ] + additional_dependencies: + - flake8-docstrings diff --git a/docs/source/config.rst b/docs/source/config.rst index 01f1e325..50594549 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -20,13 +20,24 @@ The environment config file consists of the following attributes: **Generic Config Values** -* **agent_identifier** [enum] - This identifies the agent to use for the session. Select from one of the following: +* **agent_framework** [enum] + + This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following: + + * NONE - Where a user developed agent is to be used + * SB3 - Stable Baselines3 + * RLLIB - Ray RLlib. + +* **agent_identifier** + + This identifies the agent to use for the session. 
Select from one of the following: + + * A2C - Advantage Actor Critic + * PPO - Proximal Policy Optimization + * HARDCODED - A custom built deterministic agent + * RANDOM - A Stochastic random agent - * GENERIC - Where a user developed agent is to be used - * STABLE_BASELINES3_PPO - Use a SB3 PPO agent - * STABLE_BASELINES3_A2C - use a SB3 A2C agent * **random_red_agent** [bool] @@ -34,38 +45,38 @@ The environment config file consists of the following attributes: * **action_type** [enum] - Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session + Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session * **num_episodes** [int] - This defines the number of episodes that the agent will train or be evaluated over. + This defines the number of episodes that the agent will train or be evaluated over. * **num_steps** [int] - Determines the number of steps to run in each episode of the session + Determines the number of steps to run in each episode of the session * **time_delay** [int] - The time delay (in milliseconds) to take between each step when running a GENERIC agent session + The time delay (in milliseconds) to take between each step when running a GENERIC agent session * **session_type** [text] - Type of session to be run (TRAINING or EVALUATION) + Type of session to be run (TRAINING, EVALUATION, or BOTH) * **load_agent** [bool] - Determine whether to load an agent from file + Determine whether to load an agent from file * **agent_load_file** [text] - File path and file name of agent if you're loading one in + File path and file name of agent if you're loading one in * **observation_space_high_value** [int] - The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases + The high value to use for values in the observation space. 
This is set to 1000000000 by default, and should not need changing in most cases **Reward-Based Config Values** @@ -73,95 +84,95 @@ Rewards are calculated based on the difference between the current state and ref * **Generic [all_ok]** [int] - The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) + The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) * **Node Hardware State [off_should_be_on]** [int] - The score to give when the node should be on, but is off + The score to give when the node should be on, but is off * **Node Hardware State [off_should_be_resetting]** [int] - The score to give when the node should be resetting, but is off + The score to give when the node should be resetting, but is off * **Node Hardware State [on_should_be_off]** [int] - The score to give when the node should be off, but is on + The score to give when the node should be off, but is on * **Node Hardware State [on_should_be_resetting]** [int] - The score to give when the node should be resetting, but is on + The score to give when the node should be resetting, but is on * **Node Hardware State [resetting_should_be_on]** [int] - The score to give when the node should be on, but is resetting + The score to give when the node should be on, but is resetting * **Node Hardware State [resetting_should_be_off]** [int] - The score to give when the node should be off, but is resetting + The score to give when the node should be off, but is resetting * **Node Hardware State [resetting]** [int] - The score to give when the node is resetting + The score to give when the node is resetting * **Node Operating System or Service State [good_should_be_patching]** [int] - The score to give when the state should be patching, but is good + The 
score to give when the state should be patching, but is good * **Node Operating System or Service State [good_should_be_compromised]** [int] - The score to give when the state should be compromised, but is good + The score to give when the state should be compromised, but is good * **Node Operating System or Service State [good_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is good + The score to give when the state should be overwhelmed, but is good * **Node Operating System or Service State [patching_should_be_good]** [int] - The score to give when the state should be good, but is patching + The score to give when the state should be good, but is patching * **Node Operating System or Service State [patching_should_be_compromised]** [int] - The score to give when the state should be compromised, but is patching + The score to give when the state should be compromised, but is patching * **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is patching + The score to give when the state should be overwhelmed, but is patching * **Node Operating System or Service State [patching]** [int] - The score to give when the state is patching + The score to give when the state is patching * **Node Operating System or Service State [compromised_should_be_good]** [int] - The score to give when the state should be good, but is compromised + The score to give when the state should be good, but is compromised * **Node Operating System or Service State [compromised_should_be_patching]** [int] - The score to give when the state should be patching, but is compromised + The score to give when the state should be patching, but is compromised * **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is compromised + The score to give when the state should be 
overwhelmed, but is compromised * **Node Operating System or Service State [compromised]** [int] - The score to give when the state is compromised + The score to give when the state is compromised * **Node Operating System or Service State [overwhelmed_should_be_good]** [int] - The score to give when the state should be good, but is overwhelmed + The score to give when the state should be good, but is overwhelmed * **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] - The score to give when the state should be patching, but is overwhelmed + The score to give when the state should be patching, but is overwhelmed * **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] - The score to give when the state should be compromised, but is overwhelmed + The score to give when the state should be compromised, but is overwhelmed * **Node Operating System or Service State [overwhelmed]** [int] - The score to give when the state is overwhelmed + The score to give when the state is overwhelmed * **Node File System State [good_should_be_repairing]** [int] @@ -265,37 +276,37 @@ Rewards are calculated based on the difference between the current state and ref * **IER Status [red_ier_running]** [int] - The score to give when a red agent IER is permitted to run + The score to give when a red agent IER is permitted to run * **IER Status [green_ier_blocked]** [int] - The score to give when a green agent IER is prevented from running + The score to give when a green agent IER is prevented from running **Patching / Reset Durations** * **os_patching_duration** [int] - The number of steps to take when patching an Operating System + The number of steps to take when patching an Operating System * **node_reset_duration** [int] - The number of steps to take when resetting a node's hardware state + The number of steps to take when resetting a node's hardware state * **service_patching_duration** [int] - The number of steps to take when 
patching a service + The number of steps to take when patching a service * **file_system_repairing_limit** [int]: - The number of steps to take when repairing the file system + The number of steps to take when repairing the file system * **file_system_restoring_limit** [int] - The number of steps to take when restoring the file system + The number of steps to take when restoring the file system * **file_system_scanning_limit** [int] - The number of steps to take when scanning the file system + The number of steps to take when scanning the file system * **deterministic** [bool] @@ -312,22 +323,22 @@ The lay down config file consists of the following attributes: * **itemType: ACTIONS** [enum] - Determines whether a NODE or ACL action space format is adopted for the session + Determines whether a NODE or ACL action space format is adopted for the session * **itemType: OBSERVATION_SPACE** [dict] - Allows for user to configure observation space by combining one or more observation components. List of available - components is is :py:mod:'primaite.environment.observations'. + Allows the user to configure observation space by combining one or more observation components. List of available + components is in :py:mod:`primaite.environment.observations`. - The observation space config item should have a ``components`` key which is a list of components. Each component - config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the - component while it is being initialised. + The observation space config item should have a ``components`` key which is a list of components. Each component + config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the + component while it is being initialised. - This example illustrates the correct format for the observation space config item + This example illustrates the correct format for the observation space config item ..
code-block::yaml - - itemType: OBSERVATION_SPACE + - item_type: OBSERVATION_SPACE components: - name: LINK_TRAFFIC_LEVELS options: @@ -340,15 +351,15 @@ The lay down config file consists of the following attributes: * **item_type: PORTS** [int] - Provides a list of ports modelled in this session + Provides a list of ports modelled in this session * **item_type: SERVICES** [freetext] - Provides a list of services modelled in this session + Provides a list of services modelled in this session * **item_type: NODE** - Defines a node included in the system laydown being simulated. It should consist of the following attributes: + Defines a node included in the system laydown being simulated. It should consist of the following attributes: * **id** [int]: Unique ID for this YAML item * **name** [freetext]: Human-readable name of the component @@ -367,7 +378,7 @@ The lay down config file consists of the following attributes: * **item_type: LINK** - Defines a link included in the system laydown being simulated. It should consist of the following attributes: + Defines a link included in the system laydown being simulated. It should consist of the following attributes: * **id** [int]: Unique ID for this YAML item * **name** [freetext]: Human-readable name of the component @@ -377,7 +388,7 @@ The lay down config file consists of the following attributes: * **item_type: GREEN_IER** - Defines a green agent Information Exchange Requirement (IER). It should consist of: + Defines a green agent Information Exchange Requirement (IER). It should consist of: * **id** [int]: Unique ID for this YAML item * **start_step** [int]: The start step (in the episode) for this IER to begin @@ -391,7 +402,7 @@ The lay down config file consists of the following attributes: * **item_type: RED_IER** - Defines a red agent Information Exchange Requirement (IER). It should consist of: + Defines a red agent Information Exchange Requirement (IER). 
It should consist of: * **id** [int]: Unique ID for this YAML item * **start_step** [int]: The start step (in the episode) for this IER to begin diff --git a/docs/source/primaite-dependencies.rst b/docs/source/primaite-dependencies.rst index 67971d2b..d5511a55 100644 --- a/docs/source/primaite-dependencies.rst +++ b/docs/source/primaite-dependencies.rst @@ -1,323 +1,429 @@ -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Name | Version | License | URL | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Babel | 2.12.1 | BSD License | https://babel.pocoo.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Jinja2 | 3.1.2 | BSD License | https://palletsprojects.com/p/jinja/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| MarkupSafe | 2.1.3 | BSD License | https://palletsprojects.com/p/markupsafe/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Pillow | 9.5.0 | Historical Permission Notice and Disclaimer (HPND) | https://python-pillow.org | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| PyYAML | 6.0 | MIT License | https://pyyaml.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Pygments | 2.15.1 | BSD License | https://pygments.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Send2Trash | 1.8.2 | BSD License | https://github.com/arsenetar/send2trash | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| Sphinx | 6.1.3 | BSD License | https://www.sphinx-doc.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| aiofiles | 22.1.0 | Apache Software License | https://github.com/Tinche/aiofiles | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| aiosqlite | 0.19.0 | MIT License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| alabaster | 0.7.13 | BSD License | https://alabaster.readthedocs.io | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| anyio | 3.7.0 | MIT License | https://anyio.readthedocs.io/en/stable/versionhistory.html | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| argon2-cffi | 21.3.0 | MIT License | https://github.com/hynek/argon2-cffi/blob/main/CHANGELOG.md | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| argon2-cffi-bindings | 21.2.0 | MIT License | https://github.com/hynek/argon2-cffi-bindings | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| arrow | 1.2.3 | Apache Software License | https://arrow.readthedocs.io | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| beautifulsoup4 | 4.12.2 | MIT License | https://www.crummy.com/software/BeautifulSoup/bs4/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| bleach | 6.0.0 | Apache Software License | https://github.com/mozilla/bleach | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| build | 0.10.0 | MIT License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| certifi | 2023.5.7 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/certifi/python-certifi | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| cffi | 1.15.1 | MIT License | http://cffi.readthedocs.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| cfgv | 3.3.1 | 
MIT License | https://github.com/asottile/cfgv | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| charset-normalizer | 3.1.0 | MIT License | https://github.com/Ousret/charset_normalizer | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| click | 8.1.3 | BSD License | https://palletsprojects.com/p/click/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| cloudpickle | 2.2.1 | BSD License | https://github.com/cloudpipe/cloudpickle | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| colorama | 0.4.6 | BSD License | https://github.com/tartley/colorama | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| comm | 0.1.3 | BSD License | https://github.com/ipython/comm | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| contourpy | 1.0.7 | BSD License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| 
coverage | 7.2.7 | Apache Software License | https://github.com/nedbat/coveragepy | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| cycler | 0.11.0 | BSD License | https://github.com/matplotlib/cycler | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| debugpy | 1.6.7 | Eclipse Public License 2.0 (EPL-2.0); MIT License | https://aka.ms/debugpy | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| decorator | 5.1.1 | BSD License | https://github.com/micheles/decorator | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| defusedxml | 0.7.1 | Python Software Foundation License | https://github.com/tiran/defusedxml | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| distlib | 0.3.6 | Python Software Foundation License | https://github.com/pypa/distlib | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| docutils | 0.19 | BSD License; GNU General Public License (GPL); Public Domain; Python Software Foundation License | https://docutils.sourceforge.io/ | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| exceptiongroup | 1.1.1 | MIT License | https://github.com/agronholm/exceptiongroup/blob/main/CHANGES.rst | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| executing | 1.2.0 | MIT License | https://github.com/alexmojaki/executing | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| fastjsonschema | 2.17.1 | BSD License | https://github.com/horejsek/python-fastjsonschema | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| filelock | 3.12.1 | The Unlicense (Unlicense) | https://github.com/tox-dev/py-filelock | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| fonttools | 4.40.0 | MIT License | http://github.com/fonttools/fonttools | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| furo | 2023.3.27 | MIT License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| gym | 0.21.0 | UNKNOWN | https://github.com/openai/gym | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| idna | 3.4 | BSD License | https://github.com/kjd/idna | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| imagesize | 1.4.1 | MIT License | https://github.com/shibukawa/imagesize_py | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| importlib-metadata | 4.13.0 | Apache Software License | 
https://github.com/python/importlib_metadata | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| importlib-resources | 5.12.0 | Apache Software License | https://github.com/python/importlib_resources | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| iniconfig | 2.0.0 | MIT License | https://github.com/pytest-dev/iniconfig | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| ipykernel | 6.23.2 | BSD License | https://ipython.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| ipython | 8.12.2 | BSD License | https://ipython.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| ipython-genutils | 0.2.0 | BSD License | http://ipython.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jsonpointer | 2.3 | BSD License | https://github.com/stefankoegl/python-json-pointer | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jsonschema | 4.17.3 | MIT License | https://github.com/python-jsonschema/jsonschema | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter-events | 0.6.3 | BSD License | http://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter_client | 
8.2.0 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter_core | 5.3.0 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyter_server_ydoc | 0.6.1 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyterlab | 3.6.1 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyterlab-pygments | 0.2.2 | BSD 
| https://github.com/jupyterlab/jupyterlab_pygments | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| jupyterlab_server | 2.23.0 | BSD License | https://jupyterlab-server.readthedocs.io | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| kiwisolver | 1.4.4 | BSD License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| markdown-it-py | 3.0.0 | MIT License | https://github.com/executablebooks/markdown-it-py | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| matplotlib | 3.7.1 | Python Software Foundation License | https://matplotlib.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| mpmath | 1.3.0 | BSD License | http://mpmath.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| nbclient | 0.8.0 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| nbconvert | 7.4.0 | BSD License | https://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| nbformat | 5.9.0 | BSD License | https://jupyter.org | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| nest-asyncio | 1.5.6 | BSD License | https://github.com/erdewit/nest_asyncio | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| networkx | 3.1 | BSD License | https://networkx.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| nodeenv | 1.8.0 | BSD License | https://github.com/ekalinin/nodeenv | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| notebook | 6.5.4 | BSD License | http://jupyter.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| notebook_shim | 0.2.3 | BSD License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| numpy | 1.23.5 | BSD License | https://www.numpy.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pandas | 2.0.2 | BSD License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pandocfilters | 1.5.0 | BSD License | http://github.com/jgm/pandocfilters | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| parso | 0.8.3 | MIT License | https://github.com/davidhalter/parso | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pickleshare | 0.7.5 | MIT License | https://github.com/pickleshare/pickleshare | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pkgutil_resolve_name | 1.3.10 | MIT License | https://github.com/graingert/pkgutil-resolve-name | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| 
platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| primaite | 2.0.0.dev0 | MIT License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| primaite | 2.0.0.dev0 | MIT License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| prompt-toolkit | 3.0.38 | BSD License | https://github.com/prompt-toolkit/python-prompt-toolkit | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| psutil | 5.9.5 | BSD License | https://github.com/giampaolo/psutil | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pure-eval | 0.2.2 | MIT License | http://github.com/alexmojaki/pure_eval | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pycodestyle | 2.10.0 | MIT License | https://pycodestyle.pycqa.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pycparser | 2.21 | BSD License | https://github.com/eliben/pycparser | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pyflakes | 3.0.1 | MIT License | https://github.com/PyCQA/pyflakes | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pyparsing | 3.0.9 | MIT License | https://github.com/pyparsing/pyparsing/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pyproject_hooks | 1.0.0 | MIT 
License | https://github.com/pypa/pyproject-hooks | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pyrsistent | 0.19.3 | MIT License | https://github.com/tobgu/pyrsistent/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pytest | 7.2.0 | MIT License | https://docs.pytest.org/en/latest/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pytest-cov | 4.0.0 | MIT License | https://github.com/pytest-dev/pytest-cov | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pytest-flake8 | 1.1.1 | BSD License | https://github.com/tholo/pytest-flake8 | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| python-dateutil | 2.8.2 | Apache Software License; BSD License | https://github.com/dateutil/dateutil | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| python-json-logger | 2.0.7 | BSD License | http://github.com/madzak/python-json-logger | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pytz | 2023.3 | MIT License | http://pythonhosted.org/pytz | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pywin32 | 306 | Python Software Foundation License | https://github.com/mhammond/pywin32 | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pywinpty | 2.0.10 | MIT | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| pyzmq | 25.1.0 | BSD License; GNU Library or Lesser General Public License (LGPL) | https://pyzmq.readthedocs.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| requests | 2.31.0 | Apache Software License | https://requests.readthedocs.io | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| rfc3339-validator | 0.1.4 | MIT License | https://github.com/naimetti/rfc3339-validator | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| rfc3986-validator | 0.1.1 | MIT License | https://github.com/naimetti/rfc3986-validator | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| rich | 13.4.2 | MIT License | https://github.com/Textualize/rich | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| shellingham | 1.5.0.post1 | ISC License (ISCL) | https://github.com/sarugaku/shellingham | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| six | 1.16.0 | MIT License | https://github.com/benjaminp/six | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sniffio | 1.3.0 | Apache Software License; MIT License | https://github.com/python-trio/sniffio | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| snowballstemmer | 2.2.0 | BSD License | https://github.com/snowballstem/snowball | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| soupsieve | 2.4.1 | MIT License | https://github.com/facelessuser/soupsieve | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinx-basic-ng | 1.0.0b1 | MIT License | https://github.com/pradyunsg/sphinx-basic-ng | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinx-code-tabs | 0.5.3 | The Unlicense (Unlicense) | https://github.com/coldfix/sphinx-code-tabs | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinx-copybutton | 0.5.2 | MIT License | https://github.com/executablebooks/sphinx-copybutton | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinxcontrib-applehelp | 1.0.4 | BSD License | https://www.sphinx-doc.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinxcontrib-devhelp | 1.0.2 | BSD License | http://sphinx-doc.org/ | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinxcontrib-htmlhelp | 2.0.1 | BSD License | https://www.sphinx-doc.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinxcontrib-jsmath | 1.0.1 | BSD License | http://sphinx-doc.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinxcontrib-qthelp | 1.0.3 | BSD License | http://sphinx-doc.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sphinxcontrib-serializinghtml | 1.1.5 | BSD License | http://sphinx-doc.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| stable-baselines3 | 1.6.2 | MIT | https://github.com/DLR-RM/stable-baselines3 | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| stack-data | 0.6.2 | MIT License | http://github.com/alexmojaki/stack_data | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| sympy | 1.12 
| BSD License | https://sympy.org | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| tinycss2 | 1.2.1 | BSD License | https://www.courtbouillon.org/tinycss2 | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| toml | 0.10.2 | MIT License | https://github.com/uiri/toml | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| tomli | 2.0.1 | MIT License | https://github.com/hukkin/tomli | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| torch | 2.0.1 | BSD License | https://pytorch.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| tornado | 6.3.2 | Apache Software License | http://www.tornadoweb.org/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| traitlets | 5.9.0 | BSD License 
| https://github.com/ipython/traitlets | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| typer | 0.9.0 | MIT License | https://github.com/tiangolo/typer | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| typing_extensions | 4.6.3 | Python Software Foundation License | https://github.com/python/typing_extensions/issues | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| tzdata | 2023.3 | Apache Software License | https://github.com/python/tzdata | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| uri-template | 1.2.0 | MIT License | https://github.com/plinss/uri_template/ | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| urllib3 | 2.0.3 | MIT License | https://github.com/urllib3/urllib3/blob/main/CHANGES.rst | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| virtualenv | 20.23.0 | MIT License | https://github.com/pypa/virtualenv | 
-+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| webcolors | 1.13 | BSD License | UNKNOWN | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| webencodings | 0.5.1 | BSD License | https://github.com/SimonSapin/python-webencodings | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| websocket-client | 1.5.3 | Apache Software License | https://github.com/websocket-client/websocket-client.git | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ -| zipp | 3.15.0 | MIT License | https://github.com/jaraco/zipp | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Name | Version | License | URL | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Babel | 2.12.1 | BSD License | https://babel.pocoo.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| GPUtil | 1.4.0 | MIT | https://github.com/anderskm/gputil | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Gymnasium | 0.26.3 | MIT | https://gymnasium.farama.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Jinja2 | 3.1.2 | BSD License | https://palletsprojects.com/p/jinja/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Markdown | 3.4.3 | BSD License | https://Python-Markdown.github.io/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| MarkupSafe | 2.1.2 | BSD License 
| https://palletsprojects.com/p/markupsafe/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Pillow | 9.5.0 | Historical Permission Notice and Disclaimer (HPND) | https://python-pillow.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| PyWavelets | 1.4.1 | MIT License | https://github.com/PyWavelets/pywt | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| PyYAML | 6.0 | MIT License | https://pyyaml.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Pygments | 2.15.1 | BSD License | https://pygments.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Send2Trash | 1.8.2 | BSD License | https://github.com/arsenetar/send2trash | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Sphinx | 6.1.3 | BSD License | https://www.sphinx-doc.org/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| Werkzeug | 2.3.4 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| absl-py | 1.4.0 | Apache Software License | https://github.com/abseil/abseil-py | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| aiofiles | 22.1.0 | Apache Software License | https://github.com/Tinche/aiofiles | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| aiosignal | 1.3.1 | Apache Software License | https://github.com/aio-libs/aiosignal | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| aiosqlite | 0.19.0 | MIT License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| alabaster | 0.7.13 | BSD License | https://alabaster.readthedocs.io | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| anyio | 3.7.0 | MIT License | https://anyio.readthedocs.io/en/stable/versionhistory.html | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| argon2-cffi | 21.3.0 | MIT License | https://github.com/hynek/argon2-cffi/blob/main/CHANGELOG.md | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| argon2-cffi-bindings | 21.2.0 | MIT License | https://github.com/hynek/argon2-cffi-bindings | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| arrow | 1.2.3 | Apache Software License | https://arrow.readthedocs.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| astunparse | 1.6.3 | BSD License | https://github.com/simonpercivall/astunparse | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| beautifulsoup4 | 4.12.2 | MIT License | https://www.crummy.com/software/BeautifulSoup/bs4/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| bleach | 6.0.0 | Apache Software License | https://github.com/mozilla/bleach | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| build | 0.10.0 | MIT License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| cachetools | 5.3.0 | MIT License | https://github.com/tkem/cachetools/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| certifi | 2023.5.7 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/certifi/python-certifi | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| cffi | 1.15.1 | MIT License | http://cffi.readthedocs.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| cfgv | 3.3.1 | MIT License | https://github.com/asottile/cfgv | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| charset-normalizer | 3.1.0 | MIT License | https://github.com/Ousret/charset_normalizer | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| click | 8.1.3 | BSD License | https://palletsprojects.com/p/click/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| cloudpickle | 2.2.1 | BSD License | https://github.com/cloudpipe/cloudpickle | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| colorama | 0.4.6 | BSD License | https://github.com/tartley/colorama | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| comm | 0.1.3 | BSD License | https://github.com/ipython/comm | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| contourpy | 1.0.7 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| coverage | 7.2.6 | Apache Software License | https://github.com/nedbat/coveragepy | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| cycler | 0.11.0 | BSD License | https://github.com/matplotlib/cycler | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| debugpy | 1.6.7 | Eclipse Public License 2.0 (EPL-2.0); MIT License | https://aka.ms/debugpy | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| decorator | 5.1.1 | BSD License | https://github.com/micheles/decorator | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| defusedxml | 0.7.1 | Python Software Foundation License | https://github.com/tiran/defusedxml | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| distlib | 0.3.6 | Python Software Foundation License | https://github.com/pypa/distlib | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| dm-tree | 0.1.8 | Apache Software License | https://github.com/deepmind/tree | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| docutils | 0.19 | BSD License; GNU General Public License (GPL); Public Domain; Python Software Foundation License | https://docutils.sourceforge.io/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| exceptiongroup | 1.1.1 | MIT License | https://github.com/agronholm/exceptiongroup/blob/main/CHANGES.rst | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| executing | 1.2.0 | MIT License | https://github.com/alexmojaki/executing | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| fastjsonschema | 2.17.1 | BSD License | https://github.com/horejsek/python-fastjsonschema | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| filelock | 3.12.0 | The Unlicense (Unlicense) | https://github.com/tox-dev/py-filelock | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| frozenlist | 1.3.3 | Apache Software License | https://github.com/aio-libs/frozenlist | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| furo | 2023.3.27 | MIT License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| grpcio | 1.51.3 | Apache Software License | https://grpc.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| gym | 0.21.0 | UNKNOWN | https://github.com/openai/gym | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| h5py | 3.9.0 | BSD License | https://www.h5py.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| idna | 3.4 | BSD License | https://github.com/kjd/idna | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| imageio | 2.29.0 | BSD License | https://github.com/imageio/imageio | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| imagesize | 1.4.1 | MIT License | https://github.com/shibukawa/imagesize_py | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| importlib-metadata | 4.13.0 | Apache Software License | https://github.com/python/importlib_metadata | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| iniconfig | 2.0.0 | MIT License | https://github.com/pytest-dev/iniconfig | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ipykernel | 6.23.1 | BSD License | https://ipython.org | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ipython | 8.13.2 | BSD License | https://ipython.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ipython-genutils | 0.2.0 | BSD License | http://ipython.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jsonpointer | 2.3 | BSD License | https://github.com/stefankoegl/python-json-pointer | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jsonschema | 4.17.3 | MIT License | https://github.com/python-jsonschema/jsonschema | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter-events | 0.6.3 | BSD License | http://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter_client | 8.2.0 | BSD License | https://jupyter.org | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter_core | 5.3.0 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter_server_ydoc | 0.6.1 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyterlab | 3.6.1 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyterlab-pygments | 0.2.2 | BSD | https://github.com/jupyterlab/jupyterlab_pygments | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| keras | 2.12.0 | Apache Software License | https://keras.io/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| kiwisolver | 1.4.4 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| matplotlib | 3.7.1 | Python Software Foundation License | https://matplotlib.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mpmath | 1.3.0 | BSD License | http://mpmath.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| nbclient | 0.8.0 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| nbconvert | 7.4.0 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| nbformat | 5.9.0 | BSD License | https://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| nest-asyncio | 1.5.6 | BSD License | https://github.com/erdewit/nest_asyncio | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| networkx | 3.1 | BSD License | https://networkx.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| nodeenv | 1.8.0 | BSD License | https://github.com/ekalinin/nodeenv | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| notebook | 
6.5.4 | BSD License | http://jupyter.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| notebook_shim | 0.2.3 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| numpy | 1.23.5 | BSD License | https://www.numpy.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pandas | 2.0.1 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pandocfilters | 1.5.0 | BSD License | http://github.com/jgm/pandocfilters | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| parso | 0.8.3 | MIT License | https://github.com/davidhalter/parso | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pickleshare | 0.7.5 | MIT License | https://github.com/pickleshare/pickleshare | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| plotly | 5.15.0 | MIT License | https://plotly.com/python/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | MIT License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | MIT License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | MIT License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| prompt-toolkit | 3.0.38 | BSD License | https://github.com/prompt-toolkit/python-prompt-toolkit | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| protobuf | 3.20.3 | BSD-3-Clause | https://developers.google.com/protocol-buffers/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| psutil | 5.9.5 | BSD License | https://github.com/giampaolo/psutil | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pure-eval | 0.2.2 | MIT License | http://github.com/alexmojaki/pure_eval | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyasn1 | 0.5.0 | BSD License | https://github.com/pyasn1/pyasn1 | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyasn1-modules | 0.3.0 | BSD License | https://github.com/pyasn1/pyasn1-modules | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pycodestyle | 2.10.0 | MIT License | https://pycodestyle.pycqa.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pycparser | 2.21 | BSD License | https://github.com/eliben/pycparser | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyflakes | 3.0.1 | MIT License | https://github.com/PyCQA/pyflakes | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyparsing | 3.0.9 | MIT License | https://github.com/pyparsing/pyparsing/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyproject_hooks | 1.0.0 | MIT License | https://github.com/pypa/pyproject-hooks | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyrsistent | 0.19.3 | MIT License | https://github.com/tobgu/pyrsistent/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pytest | 7.2.0 | MIT License | https://docs.pytest.org/en/latest/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pytest-cov | 4.0.0 | MIT License | https://github.com/pytest-dev/pytest-cov | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pytest-flake8 | 1.1.1 | BSD License | https://github.com/tholo/pytest-flake8 | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| python-dateutil | 2.8.2 | Apache Software License; BSD License | https://github.com/dateutil/dateutil | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| python-json-logger | 2.0.7 | BSD License | http://github.com/madzak/python-json-logger | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pytz | 2023.3 | MIT License | http://pythonhosted.org/pytz | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pywin32 | 306 | Python Software Foundation License | https://github.com/mhammond/pywin32 | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pywinpty | 2.0.10 | MIT | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| pyzmq | 25.1.0 | BSD License; GNU Library or Lesser General Public License (LGPL) | https://pyzmq.readthedocs.org | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ray | 2.2.0 | Apache 2.0 | https://github.com/ray-project/ray | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| requests | 2.31.0 | Apache Software License | https://requests.readthedocs.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| requests-oauthlib | 1.3.1 | BSD License | https://github.com/requests/requests-oauthlib | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| rfc3339-validator | 0.1.4 | MIT License | https://github.com/naimetti/rfc3339-validator | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| rfc3986-validator | 0.1.1 | MIT License | https://github.com/naimetti/rfc3986-validator | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| rich | 13.3.5 | MIT License | https://github.com/Textualize/rich | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| rsa | 4.9 | Apache Software License | https://stuvel.eu/rsa | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| scikit-image | 0.20.0 | BSD License | https://scikit-image.org | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| scipy | 1.10.1 | BSD License | https://scipy.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| shellingham | 1.5.0.post1 | ISC License (ISCL) | https://github.com/sarugaku/shellingham | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| six | 1.16.0 | MIT License | https://github.com/benjaminp/six | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sniffio | 1.3.0 | Apache Software License; MIT License | https://github.com/python-trio/sniffio | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| snowballstemmer | 2.2.0 | BSD License | https://github.com/snowballstem/snowball | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| soupsieve | 2.4.1 | MIT License | https://github.com/facelessuser/soupsieve | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinx-basic-ng | 1.0.0b1 | MIT License | https://github.com/pradyunsg/sphinx-basic-ng | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinx-code-tabs | 0.5.3 | The Unlicense (Unlicense) | https://github.com/coldfix/sphinx-code-tabs | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinx-copybutton | 0.5.2 | MIT License | https://github.com/executablebooks/sphinx-copybutton | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinxcontrib-applehelp | 1.0.4 | BSD License | https://www.sphinx-doc.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinxcontrib-devhelp | 1.0.2 | BSD License | http://sphinx-doc.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinxcontrib-htmlhelp | 2.0.1 | BSD License | https://www.sphinx-doc.org/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinxcontrib-jsmath | 1.0.1 | BSD License | http://sphinx-doc.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinxcontrib-qthelp | 1.0.3 | BSD License | http://sphinx-doc.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sphinxcontrib-serializinghtml | 1.1.5 | BSD License | http://sphinx-doc.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| stable-baselines3 | 1.6.2 | MIT | https://github.com/DLR-RM/stable-baselines3 | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| stack-data | 0.6.2 | MIT License | http://github.com/alexmojaki/stack_data | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| sympy | 1.12 | BSD License | https://sympy.org | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tinycss2 | 1.2.1 | BSD License | https://www.courtbouillon.org/tinycss2 | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| toml | 0.10.2 | MIT License | https://github.com/uiri/toml | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tomli | 2.0.1 | MIT License | https://github.com/hukkin/tomli | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| torch | 2.0.1 | BSD License | https://pytorch.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tornado | 6.3.2 | Apache Software License | http://www.tornadoweb.org/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| traitlets | 5.9.0 | BSD License | https://github.com/ipython/traitlets | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| typer | 0.9.0 | MIT License | https://github.com/tiangolo/typer | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| typing_extensions | 4.6.2 | Python Software Foundation License | https://github.com/python/typing_extensions/issues | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tzdata | 2023.3 | Apache Software License | https://github.com/python/tzdata | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| uri-template | 1.2.0 | MIT License | https://github.com/plinss/uri_template/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| urllib3 | 1.26.16 | MIT License | https://urllib3.readthedocs.io/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| virtualenv | 20.21.0 | MIT License | https://github.com/pypa/virtualenv | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| webcolors | 1.13 | BSD License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| webencodings | 0.5.1 | BSD License | https://github.com/SimonSapin/python-webencodings | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| zipp | 3.15.0 | MIT License | https://github.com/jaraco/zipp | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ diff --git a/pyproject.toml b/pyproject.toml index 7ddf7710..86418eaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "primaite" description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme." 
authors = [{name="QinetiQ Training and Simulation Ltd"}] license = {text = "MIT License"} -requires-python = ">=3.8" +requires-python = ">=3.8, <3.11" dynamic = ["version", "readme"] classifiers = [ "License :: MIT License", @@ -20,19 +20,23 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", ] dependencies = [ "gym==0.21.0", "jupyterlab==3.6.1", + "kaleido==0.2.1", "matplotlib==3.7.1", "networkx==3.1", "numpy==1.23.5", "platformdirs==3.5.1", + "plotly==5.15.0", + "polars==0.18.4", "PyYAML==6.0", + "ray[rllib]==2.2.0", "stable-baselines3==1.6.2", + "tensorflow==2.12.0", "typer[all]==0.9.0" ] @@ -62,9 +66,15 @@ dev = [ "sphinx-copybutton==0.5.2", "wheel==0.38.4" ] -tensorflow = [ - "tensorflow==2.12.0", -] [project.scripts] primaite = "primaite.cli:app" + +[tool.isort] +profile = "black" +line_length = 120 +force_sort_within_sections = "False" +order_by_type = "False" + +[tool.black] +line-length = 120 diff --git a/src/primaite/VERSION b/src/primaite/VERSION index bd82b28c..4111d137 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0dev0 +2.0.0rc1 diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 420420f4..030860d8 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -2,10 +2,11 @@ import logging import logging.config import sys -from logging import Logger, StreamHandler +from bisect import bisect +from logging import Formatter, Logger, LogRecord, StreamHandler from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Final +from typing import Dict, Final import pkg_resources import yaml @@ -18,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not 
config_path.exists(): - config_path = Path( - pkg_resources.resource_filename( - "primaite", "setup/_package_data/primaite_config.yaml" - ) - ) + config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) with open(config_path, "r") as file: primaite_config = yaml.safe_load(file) log_level_map = { @@ -33,7 +30,7 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[primaite_config["log_level"]] + primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] return primaite_config @@ -68,6 +65,28 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. # region Setup Logging +class _LevelFormatter(Formatter): + """ + A custom level-specific formatter. + + Credit to: https://stackoverflow.com/a/68154386 + """ + + def __init__(self, formats: Dict[int, str], **kwargs): + super().__init__() + + if "fmt" in kwargs: + raise ValueError("Format string must be passed to level-surrogate formatters, " "not this one") + + self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()) + + def format(self, record: LogRecord) -> str: + """Overrides ``Formatter.format``.""" + idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1) + level, formatter = self.formats[idx] + return formatter.format(record) + + def _log_dir() -> Path: if sys.platform == "win32": dir_path = _PLATFORM_DIRS.user_data_path / "logs" @@ -76,6 +95,16 @@ def _log_dir() -> Path: return dir_path +_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( + { + logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], + logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], + logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], + logging.CRITICAL: 
_PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"], + } +) + LOG_DIR: Final[Path] = _log_dir() """The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS.""" @@ -85,18 +114,19 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log" """The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS.""" _STREAM_HANDLER: Final[StreamHandler] = StreamHandler() + _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( filename=LOG_PATH, maxBytes=10485760, # 10MB backupCount=9, # Max 100MB of logs encoding="utf8", ) -_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) -_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) +_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) +_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) -_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"] -_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) -_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) +_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"] +_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) +_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) _LOGGER = logging.getLogger(__name__) @@ -104,7 +134,7 @@ _LOGGER.addHandler(_STREAM_HANDLER) _LOGGER.addHandler(_FILE_HANDLER) -def getLogger(name: str) -> Logger: +def getLogger(name: str) -> Logger: # noqa """ Get a PrimAITE logger. diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 284ed764..3b0e9234 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -25,18 +25,9 @@ class AccessControlList: True if match; False otherwise. 
""" if ( - ( - _rule.get_source_ip() == _source_ip_address - and _rule.get_dest_ip() == _dest_ip_address - ) - or ( - _rule.get_source_ip() == "ANY" - and _rule.get_dest_ip() == _dest_ip_address - ) - or ( - _rule.get_source_ip() == _source_ip_address - and _rule.get_dest_ip() == "ANY" - ) + (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) + or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) + or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY") ): return True @@ -57,15 +48,9 @@ class AccessControlList: Indicates block if all conditions are satisfied. """ for rule_key, rule_value in self.acl.items(): - if self.check_address_match( - rule_value, _source_ip_address, _dest_ip_address - ): - if ( - rule_value.get_protocol() == _protocol - or rule_value.get_protocol() == "ANY" - ) and ( - str(rule_value.get_port()) == str(_port) - or rule_value.get_port() == "ANY" + if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): + if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and ( + str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" ): # There's a matching rule. Get the permission if rule_value.get_permission() == "DENY": diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index ef631a70..05daecc4 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -30,7 +30,13 @@ class ACLRule: Returns hash of core parameters. 
""" return hash( - (self.permission, self.source_ip, self.dest_ip, self.protocol, self.port) + ( + self.permission, + self.source_ip, + self.dest_ip, + self.protocol, + self.port, + ) ) def get_permission(self): diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py new file mode 100644 index 00000000..685fe776 --- /dev/null +++ b/src/primaite/agents/agent.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import json +import time +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from typing import Dict, Final, Union +from uuid import uuid4 + +import yaml + +import primaite +from primaite import getLogger, SESSIONS_DIR +from primaite.config import lay_down_config, training_config +from primaite.config.training_config import TrainingConfig +from primaite.data_viz.session_plots import plot_av_reward_per_episode +from primaite.environment.primaite_env import Primaite + +_LOGGER = getLogger(__name__) + + +def get_session_path(session_timestamp: datetime) -> Path: + """ + Get the directory path the session will output to. + + This is set in the format of: + ~/primaite/sessions//_. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. + """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = SESSIONS_DIR / date_dir / session_path + session_path.mkdir(exist_ok=True, parents=True) + + return session_path + + +class AgentSessionABC(ABC): + """ + An ABC that manages training and/or evaluation of agents in PrimAITE. + + This class cannot be directly instantiated and must be inherited from + with all implemented abstract methods implemented. 
+ """ + + @abstractmethod + def __init__(self, training_config_path, lay_down_config_path): + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Final[Union[Path, str]] = training_config_path + self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level + + self._env: Primaite + self._agent = None + self._can_learn: bool = False + self._can_evaluate: bool = False + self.is_eval = False + + self._uuid = str(uuid4()) + self.session_timestamp: datetime = datetime.now() + "The session timestamp" + self.session_path = get_session_path(self.session_timestamp) + "The Session path" + + @property + def timestamp_str(self) -> str: + """The session timestamp as a string.""" + return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + + @property + def learning_path(self) -> Path: + """The learning outputs path.""" + path = self.session_path / "learning" + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def evaluation_path(self) -> Path: + """The evaluation outputs path.""" + path = self.session_path / "evaluation" + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def checkpoints_path(self) -> Path: + """The Session checkpoints path.""" + path = self.learning_path / "checkpoints" + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def uuid(self): + """The Agent Session UUID.""" + return self._uuid + + def _write_session_metadata_file(self): + """ + Write the ``session_metadata.json`` file. 
+ + Creates a ``session_metadata.json`` in the ``session_path`` directory + and adds the following key/value pairs: + + - uuid: The UUID assigned to the session upon instantiation. + - start_datetime: The date & time the session started in iso format. + - end_datetime: NULL. + - total_episodes: NULL. + - total_time_steps: NULL. + - env: + - training_config: + - All training config items + - lay_down_config: + - All lay down config items + + """ + metadata_dict = { + "uuid": self.uuid, + "start_datetime": self.session_timestamp.isoformat(), + "end_datetime": None, + "learning": {"total_episodes": None, "total_time_steps": None}, + "evaluation": {"total_episodes": None, "total_time_steps": None}, + "env": { + "training_config": self._training_config.to_dict(json_serializable=True), + "lay_down_config": self._lay_down_config, + }, + } + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Writing Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished writing session metadata file") + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_path`` directory + with the following key/value pairs: + + - end_datetime: The date & time the session ended in iso format. + - total_episodes: The total number of training episodes completed. + - total_time_steps: The total number of training time steps completed. 
+ """ + with open(self.session_path / "session_metadata.json", "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + + if not self.is_eval: + metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa + else: + metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa + + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Updating Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished updating session metadata file") + + @abstractmethod + def _setup(self): + _LOGGER.info( + "Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})" + ) + _LOGGER.info(f"The output directory for this session is: {self.session_path}") + self._write_session_metadata_file() + self._can_learn = True + self._can_evaluate = False + + @abstractmethod + def _save_checkpoint(self): + pass + + @abstractmethod + def learn( + self, + **kwargs, + ): + """ + Train the agent. + + :param kwargs: Any agent-specific key-word args to be passed. + """ + if self._can_learn: + _LOGGER.info("Finished learning") + _LOGGER.debug("Writing transactions") + self._update_session_metadata_file() + self._can_evaluate = True + self.is_eval = False + self._plot_av_reward_per_episode(learning_session=True) + + @abstractmethod + def evaluate( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ + self._env.set_as_eval() # noqa + self.is_eval = True + self._plot_av_reward_per_episode(learning_session=False) + _LOGGER.info("Finished evaluation") + + @abstractmethod + def _get_latest_checkpoint(self): + pass + + @classmethod + @abstractmethod + def load(cls, path: Union[str, Path]) -> AgentSessionABC: + """Load an agent from file.""" + if not isinstance(path, Path): + path = Path(path) + + if path.exists(): + # Unpack the session_metadata.json file + md_file = path / "session_metadata.json" + with open(md_file, "r") as file: + md_dict = json.load(file) + + # Create a temp directory and dump the training and lay down + # configs into it + temp_dir = path / ".temp" + temp_dir.mkdir(exist_ok=True) + + temp_tc = temp_dir / "tc.yaml" + with open(temp_tc, "w") as file: + yaml.dump(md_dict["env"]["training_config"], file) + + temp_ldc = temp_dir / "ldc.yaml" + with open(temp_ldc, "w") as file: + yaml.dump(md_dict["env"]["lay_down_config"], file) + + agent = cls(temp_tc, temp_ldc) + + agent.session_path = path + + return agent + + else: + # Session path does not exist + msg = f"Failed to load PrimAITE Session, path does not exist: {path}" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + pass + + @abstractmethod + def save(self): + """Save the agent.""" + self._agent.save(self.session_path) + + @abstractmethod + def export(self): + """Export the agent to transportable file format.""" + pass + + def close(self): + """Closes the agent.""" + self._env.episode_av_reward_writer.close() # noqa + self._env.transaction_writer.close() # noqa + + def _plot_av_reward_per_episode(self, learning_session: bool = True): + # self.close() + title = f"PrimAITE Session {self.timestamp_str} " + subtitle = str(self._training_config) + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + image_file = f"average_reward_per_episode_{self.timestamp_str}.png" + if learning_session: + title += "(Learning)" + path = self.learning_path / csv_file + image_path = 
self.learning_path / image_file + else: + title += "(Evaluation)" + path = self.evaluation_path / csv_file + image_path = self.evaluation_path / image_file + + fig = plot_av_reward_per_episode(path, title, subtitle) + fig.write_image(image_path) + _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") + + +class HardCodedAgentSessionABC(AgentSessionABC): + """ + An Agent Session ABC for evaluation deterministic agents. + + This class cannot be directly instantiated and must be inherited from + with all implemented abstract methods implemented. + """ + + def __init__(self, training_config_path, lay_down_config_path): + super().__init__(training_config_path, lay_down_config_path) + self._setup() + + def _setup(self): + self._env: Primaite = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ) + super()._setup() + self._can_learn = False + self._can_evaluate = True + + def _save_checkpoint(self): + pass + + def _get_latest_checkpoint(self): + pass + + def learn( + self, + **kwargs, + ): + """ + Train the agent. + + :param kwargs: Any agent-specific key-word args to be passed. + """ + _LOGGER.warning("Deterministic agents cannot learn") + + @abstractmethod + def _calculate_action(self, obs): + pass + + def evaluate( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ + self._env.set_as_eval() # noqa + self.is_eval = True + + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes + + obs = self._env.reset() + for episode in range(episodes): + # Reset env and collect initial observation + for step in range(time_steps): + # Calculate action + action = self._calculate_action(obs) + + # Perform the step + obs, reward, done, info = self._env.step(action) + + if done: + break + + # Introduce a delay between steps + time.sleep(self._training_config.time_delay / 1000) + obs = self._env.reset() + self._env.close() + + @classmethod + def load(cls): + """Load an agent from file.""" + _LOGGER.warning("Deterministic agents cannot be loaded") + + def save(self): + """Save the agent.""" + _LOGGER.warning("Deterministic agents cannot be saved") + + def export(self): + """Export the agent to transportable file format.""" + _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py new file mode 100644 index 00000000..263ccbdc --- /dev/null +++ b/src/primaite/agents/hardcoded_acl.py @@ -0,0 +1,431 @@ +import numpy as np + +from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import ( + get_new_action, + get_node_of_ip, + transform_action_acl_enum, + transform_change_obs_readable, +) +from primaite.common.enums import HardCodedAgentView + + +class HardCodedACLAgent(HardCodedAgentSessionABC): + """An Agent Session class that implements a deterministic ACL agent.""" + + def _calculate_action(self, obs): + if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: + # Basic view action using only the current observation + return self._calculate_action_basic_view(obs) + else: + # full view action using observation space, action + # history and reward feedback + return self._calculate_action_full_view(obs) + + def get_blocked_green_iers(self, green_iers, acl, nodes): + 
""" + Get blocked green IERs. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + blocked_green_iers = {} + + for green_ier_id, green_ier in green_iers.items(): + source_node_id = green_ier.get_source_node_id() + source_node_address = nodes[source_node_id].ip_address + dest_node_id = green_ier.get_dest_node_id() + dest_node_address = nodes[dest_node_id].ip_address + protocol = green_ier.get_protocol() # e.g. 'TCP' + port = green_ier.get_port() + + # Can be blocked by an ACL or by default (no allow rule exists) + if acl.is_blocked(source_node_address, dest_node_address, protocol, port): + blocked_green_iers[green_ier_id] = green_ier + + return blocked_green_iers + + def get_matching_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get matching ACL rules for an IER. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + source_node_id = ier.get_source_node_id() + source_node_address = nodes[source_node_id].ip_address + dest_node_id = ier.get_dest_node_id() + dest_node_address = nodes[dest_node_id].ip_address + protocol = ier.get_protocol() # e.g. 'TCP' + port = ier.get_port() + + matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) + return matching_rules + + def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get blocking ACL rules for an IER. + + .. warning:: + Can return empty dict but IER can still be blocked by default + (No ALLOW rule, therefore blocked). + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) + + blocked_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "DENY": + blocked_rules[rule_key] = rule_value + + return blocked_rules + + def get_allow_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get all allowing ACL rules for an IER. 
+ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "ALLOW": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def get_matching_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get matching ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + if source_node_id != "ANY": + source_node_address = nodes[str(source_node_id)].ip_address + else: + source_node_address = source_node_id + + if dest_node_id != "ANY": + dest_node_address = nodes[str(dest_node_id)].ip_address + else: + dest_node_address = dest_node_id + + if protocol != "ANY": + protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services + + matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) + return matching_rules + + def get_allow_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get the ALLOW ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + matching_rules = self.get_matching_acl_rules( + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "ALLOW": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def get_deny_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get the DENY ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ + matching_rules = self.get_matching_acl_rules( + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "DENY": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def _calculate_action_full_view(self, obs): + """ + Calculate a good acl-based action for the blue agent to take. + + Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: + + - Which ACL rules already exist, - otherwise: + - The agent would perminently get stuck in a loop of performing the same action over and over. + (best action is to block something, but its already blocked but doesn't know this) + - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule, + if it doesnt know what rules exist) + - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example + in the default config one of the green IERs is blocked by default, but it has no way of knowing this + based on the observation space. Additionally, potentially in the future, once a node state + has been fixed (no longer compromised), it needs a way to know it should reallow traffic. + A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this. + + There doesn't seem like there's much that can be done if an Operating or OS State is compromised + + If a service node becomes compromised there's a decision to make - do we block that service? + Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED + Cons: Will block a green IER, decreasing the reward + We decide to block the service. + + Potentially a better solution (for the reward) would be to block the incomming traffic from compromised + nodes once a service becomes overwhelmed. 
However currently the ACL action space has no way of reversing + an overwhelmed state, so we don't do this. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + # obs = convert_to_old_obs(obs) + r_obs = transform_change_obs_readable(obs) + _, _, _, *s = r_obs + + if len(r_obs) == 4: # only 1 service + s = [*s] + + # 1. Check if node is compromised. If so we want to block its outwards services + # a. If it is comprimised check if there's an allow rule we should delete. + # cons: might delete a multi-rule from any source node (ANY -> x) + # b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate + # c. OPTIONAL (no allow rule = blocked): Add a DENY rule + found_action = False + for service_num, service_states in enumerate(s): + for x, service_state in enumerate(service_states): + if service_state == "COMPROMISED": + action_source_id = x + 1 # +1 as 0 is any + action_destination_id = "ANY" + action_protocol = service_num + 1 # +1 as 0 is any + action_port = "ANY" + + allow_rules = self.get_allow_acl_rules( + action_source_id, + action_destination_id, + action_protocol, + action_port, + self._env.acl, + self._env.nodes, + self._env.services_list, + ) + deny_rules = self.get_deny_acl_rules( + action_source_id, + action_destination_id, + action_protocol, + action_port, + self._env.acl, + self._env.nodes, + self._env.services_list, + ) + if len(allow_rules) > 0: + # Check if there's an allow rule we should delete + rule = list(allow_rules.values())[0] + action_decision = "DELETE" + action_permission = "ALLOW" + action_source_ip = rule.get_source_ip() + action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes)) + action_destination_ip = rule.get_dest_ip() + action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes)) + action_protocol_name = rule.get_protocol() + action_protocol = ( + self._env.services_list.index(action_protocol_name) + 1 + ) # 
convert name e.g. 'TCP' to index + action_port_name = rule.get_port() + action_port = ( + self._env.ports_list.index(action_port_name) + 1 + ) # convert port name e.g. '80' to index + + found_action = True + break + elif len(deny_rules) > 0: + # TODO OPTIONAL + # If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need + # to create another + # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule + continue + else: + # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked + action_decision = "CREATE" + action_permission = "DENY" + break + if found_action: + break + + # 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and + # add an Allow rule if the green IER is being blocked. + # a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule): + # If there's a DENY rule delete it if: + # - There isn't already a deny rule + # - It doesnt allows a comprimised node to become operational. + # b. 
Add an ALLOW rule if: + # - There isn't already an allow rule + # - It doesnt allows a comprimised node to become operational + + if not found_action: + # Which Green IERS are blocked + blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes) + for ier_key, ier in blocked_green_iers.items(): + # Which ALLOW rules are allowing this IER (none) + allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes) + + # If there are no blocking rules, it may be being blocked by default + # If there is already an allow rule + node_id_to_check = int(ier.get_source_node_id()) + service_name_to_check = ier.get_protocol() + service_id_to_check = self._env.services_list.index(service_name_to_check) + + # Service state of the the source node in the ier + service_state = s[service_id_to_check][node_id_to_check - 1] + + if len(allowing_rules) == 0 and service_state != "COMPROMISED": + action_decision = "CREATE" + action_permission = "ALLOW" + action_source_id = int(ier.get_source_node_id()) + action_destination_id = int(ier.get_dest_node_id()) + action_protocol_name = ier.get_protocol() + action_protocol = ( + self._env.services_list.index(action_protocol_name) + 1 + ) # convert name e.g. 'TCP' to index + action_port_name = ier.get_port() + action_port = ( + self._env.ports_list.index(action_port_name) + 1 + ) # convert port name e.g. 
'80' to index + + found_action = True + break + + if found_action: + action = [ + action_decision, + action_permission, + action_source_id, + action_destination_id, + action_protocol, + action_port, + ] + action = transform_action_acl_enum(action) + action = get_new_action(action, self._env.action_dict) + else: + # If no good/useful action has been found, just perform a nothing action + action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] + action = transform_action_acl_enum(action) + action = get_new_action(action, self._env.action_dict) + return action + + def _calculate_action_basic_view(self, obs): + """Calculate a good acl-based action for the blue agent to take. + + Uses ONLY information from the current observation with NO knowledge + of previous actions taken and NO reward feedback. + + We rely on randomness to select the precise action, as we want to + block all traffic originating from a compromised node, without being + able to tell: + 1. Which ACL rules already exist + 2. Which actions the agent has already tried. + + There is a high probability that the correct rule will not be deleted + before the state becomes overwhelmed. + + Currently, a deny rule does not overwrite an allow rule. The allow + rules must be deleted. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ + action_dict = self._env.action_dict + r_obs = transform_change_obs_readable(obs) + _, o, _, *s = r_obs + + if len(r_obs) == 4: # only 1 service + s = [*s] + + number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links) + for service_num, service_states in enumerate(s): + comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"] + if len(comprimised_states) == 0: + # No states are COMPROMISED, try the next service + continue + + compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any + action_decision = "DELETE" + action_permission = "ALLOW" + action_source_ip = compromised_node + # Randomly select a destination ID to block + action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"]) + action_destination_ip = ( + int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip + ) + action_protocol = service_num + 1 # +1 as 0 is any + # Randomly select a port to block + # Bad assumption that number of protocols equals number of ports + # AND no rules exist with an ANY port + action_port = np.random.choice(list(range(1, len(s) + 1))) + + action = [ + action_decision, + action_permission, + action_source_ip, + action_destination_ip, + action_protocol, + action_port, + ] + action = transform_action_acl_enum(action) + action = get_new_action(action, action_dict) + # We can only perform 1 action on each step + return action + + # If no good/useful action has been found, just perform a nothing action + nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] + nothing_action = transform_action_acl_enum(nothing_action) + nothing_action = get_new_action(nothing_action, action_dict) + return nothing_action diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py new file mode 100644 index 00000000..310fc178 --- /dev/null +++ b/src/primaite/agents/hardcoded_node.py @@ -0,0 +1,119 @@ +from 
class HardCodedNodeAgent(HardCodedAgentSessionABC):
    """An Agent Session class that implements a deterministic Node agent."""

    def _build_action(self, node_id, node_property, property_action, service_index):
        """
        Turn a readable node action into the action returned to the environment.

        The four values are passed as a list through ``transform_action_node_enum``
        and the result is resolved against ``self._env.action_dict`` with
        ``get_new_action``.
        """
        action = transform_action_node_enum([node_id, node_property, property_action, service_index])
        return get_new_action(action, self._env.action_dict)

    def _calculate_action(self, obs):
        """
        Calculate a good node-based action for the blue agent to take.

        Only one action can be performed per step, so the first matching rule wins.
        Rules are checked in this order:

        1. Patch the OS of the first node whose OS state is COMPROMISED.
        2. Patch the first service found in the COMPROMISED state.
        3. Patch the first service found in the OVERWHELMED state.
        4. Turn on the first node whose operating state is OFF.
        5. Otherwise issue a "NONE" property action on node 1, which is intended to do no harm.

        :param obs: The current observation, made readable via ``transform_change_obs_readable``.
        :return: The action produced by ``get_new_action`` for the chosen rule.

        TODO: Typehint params and return.
        """
        r_obs = transform_change_obs_readable(obs)
        _, operating_states, os_states, *s = r_obs

        if len(r_obs) == 4:  # only 1 service
            # NOTE(review): star-unpacking already yields a list, so this is a no-op as written.
            s = [*s]

        # Check in order of most important states (order doesn't currently
        # matter, but it probably should)
        # First see if any OS states are compromised
        for x, os_state in enumerate(os_states):
            if os_state == "COMPROMISED":
                # Node ids are 1-based; the service index isn't relevant for the OS
                return self._build_action(x + 1, "OS", "PATCHING", 0)

        # Next, see if any Services are compromised, then whether any are overwhelmed.
        # We fix the compromised state before overwhelmed state,
        # If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored.
        # Perhaps overwhelmed should be fixed automatically when the compromised PCs issues are also resolved;
        # currently there's no reason that an Overwhelmed state cannot be resolved before resolving the
        # compromised PCs.
        for bad_state in ("COMPROMISED", "OVERWHELMED"):
            for service_num, service in enumerate(s):
                for x, service_state in enumerate(service):
                    if service_state == bad_state:
                        return self._build_action(x + 1, "SERVICE", "PATCHING", service_num)

        # Finally, turn on any off nodes.
        # Fixed: this must test each node's own operating state, not the stale ``os_state``
        # left over from the OS loop above.
        for x, operating_state in enumerate(operating_states):
            if operating_state == "OFF":
                # Why reset it when we can just turn it on; service index isn't relevant for operating state.
                # Fixed: ``transform_action_node_enum`` is called with the action only, as everywhere else.
                return self._build_action(x + 1, "OPERATING", "ON", 0)

        # If no good actions, just go with an action that wont do any harm
        return self._build_action(1, "NONE", "ON", 0)
datetime import datetime +from pathlib import Path +from typing import Union + +from ray.rllib.algorithms import Algorithm +from ray.rllib.algorithms.a2c import A2CConfig +from ray.rllib.algorithms.ppo import PPOConfig +from ray.tune.logger import UnifiedLogger +from ray.tune.registry import register_env + +from primaite import getLogger +from primaite.agents.agent import AgentSessionABC +from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.environment.primaite_env import Primaite + +_LOGGER = getLogger(__name__) + + +def _env_creator(env_config): + return Primaite( + training_config_path=env_config["training_config_path"], + lay_down_config_path=env_config["lay_down_config_path"], + session_path=env_config["session_path"], + timestamp_str=env_config["timestamp_str"], + ) + + +def _custom_log_creator(session_path: Path): + logdir = session_path / "ray_results" + logdir.mkdir(parents=True, exist_ok=True) + + def logger_creator(config): + return UnifiedLogger(config, logdir, loggers=None) + + return logger_creator + + +class RLlibAgent(AgentSessionABC): + """An AgentSession class that implements a Ray RLlib agent.""" + + def __init__(self, training_config_path, lay_down_config_path): + super().__init__(training_config_path, lay_down_config_path) + if not self._training_config.agent_framework == AgentFramework.RLLIB: + msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" + _LOGGER.error(msg) + raise ValueError(msg) + if self._training_config.agent_identifier == AgentIdentifier.PPO: + self._agent_config_class = PPOConfig + elif self._training_config.agent_identifier == AgentIdentifier.A2C: + self._agent_config_class = A2CConfig + else: + msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" + _LOGGER.error(msg) + raise ValueError(msg) + self._agent_config: Union[PPOConfig, A2CConfig] + + self._current_result: dict + self._setup() + _LOGGER.debug( + f"Created 
{self.__class__.__name__} using: " + f"agent_framework={self._training_config.agent_framework}, " + f"agent_identifier=" + f"{self._training_config.agent_identifier}, " + f"deep_learning_framework=" + f"{self._training_config.deep_learning_framework}" + ) + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_path`` directory + with the following key/value pairs: + + - end_datetime: The date & time the session ended in iso format. + - total_episodes: The total number of training episodes completed. + - total_time_steps: The total number of training time steps completed. + """ + with open(self.session_path / "session_metadata.json", "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + metadata_dict["total_episodes"] = self._current_result["episodes_total"] + metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] + + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Updating Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished updating session metadata file") + + def _setup(self): + super()._setup() + register_env("primaite", _env_creator) + self._agent_config = self._agent_config_class() + + self._agent_config.environment( + env="primaite", + env_config=dict( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ), + ) + + self._agent_config.training(train_batch_size=self._training_config.num_steps) + self._agent_config.framework(framework="tf") + + self._agent_config.rollouts( + num_rollout_workers=1, + num_envs_per_worker=1, + horizon=self._training_config.num_steps, + ) + self._agent: Algorithm = 
self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) + + def _save_checkpoint(self): + checkpoint_n = self._training_config.checkpoint_every_n_episodes + episode_count = self._current_result["episodes_total"] + if checkpoint_n > 0 and episode_count > 0: + if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): + self._agent.save(str(self.checkpoints_path)) + + def learn( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-specific key-word args to be passed. + """ + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes + + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") + for i in range(episodes): + self._current_result = self._agent.train() + self._save_checkpoint() + self._agent.stop() + super().learn() + + def evaluate( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ + raise NotImplementedError + + def _get_latest_checkpoint(self): + raise NotImplementedError + + @classmethod + def load(cls, path: Union[str, Path]) -> RLlibAgent: + """Load an agent from file.""" + raise NotImplementedError + + def save(self): + """Save the agent.""" + raise NotImplementedError + + def export(self): + """Export the agent to transportable file format.""" + raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py new file mode 100644 index 00000000..f5ac44cb --- /dev/null +++ b/src/primaite/agents/sb3.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Union + +import numpy as np +from stable_baselines3 import A2C, PPO +from stable_baselines3.ppo import MlpPolicy as PPOMlp + +from primaite import getLogger +from primaite.agents.agent import AgentSessionABC +from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.environment.primaite_env import Primaite + +_LOGGER = getLogger(__name__) + + +class SB3Agent(AgentSessionABC): + """An AgentSession class that implements a Stable Baselines3 agent.""" + + def __init__(self, training_config_path, lay_down_config_path): + super().__init__(training_config_path, lay_down_config_path) + if not self._training_config.agent_framework == AgentFramework.SB3: + msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" + _LOGGER.error(msg) + raise ValueError(msg) + if self._training_config.agent_identifier == AgentIdentifier.PPO: + self._agent_class = PPO + elif self._training_config.agent_identifier == AgentIdentifier.A2C: + self._agent_class = A2C + else: + msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}" + _LOGGER.error(msg) + raise ValueError(msg) + + self._tensorboard_log_path = self.learning_path / "tensorboard_logs" + self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) + self._setup() + _LOGGER.debug( + 
f"Created {self.__class__.__name__} using: " + f"agent_framework={self._training_config.agent_framework}, " + f"agent_identifier=" + f"{self._training_config.agent_identifier}" + ) + + self.is_eval = False + + def _setup(self): + super()._setup() + self._env = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ) + self._agent = self._agent_class( + PPOMlp, + self._env, + verbose=self.sb3_output_verbose_level, + n_steps=self._training_config.num_steps, + tensorboard_log=str(self._tensorboard_log_path), + ) + + def _save_checkpoint(self): + checkpoint_n = self._training_config.checkpoint_every_n_episodes + episode_count = self._env.episode_count + if checkpoint_n > 0 and episode_count > 0: + if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") + + def _get_latest_checkpoint(self): + pass + + def learn( + self, + **kwargs, + ): + """ + Train the agent. + + :param kwargs: Any agent-specific key-word args to be passed. + """ + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes + self.is_eval = False + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") + for i in range(episodes): + self._agent.learn(total_timesteps=time_steps) + self._save_checkpoint() + self._env.reset() + self._env.close() + super().learn() + + def evaluate( + self, + deterministic: bool = True, + **kwargs, + ): + """ + Evaluate the agent. + + :param deterministic: Whether the evaluation is deterministic. + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes + self._env.set_as_eval() + self.is_eval = True + if deterministic: + deterministic_str = "deterministic" + else: + deterministic_str = "non-deterministic" + _LOGGER.info( + f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." + ) + for episode in range(episodes): + obs = self._env.reset() + + for step in range(time_steps): + action, _states = self._agent.predict(obs, deterministic=deterministic) + if isinstance(action, np.ndarray): + action = np.int64(action) + obs, rewards, done, info = self._env.step(action) + self._env.reset() + self._env.close() + super().evaluate() + + @classmethod + def load(cls, path: Union[str, Path]) -> SB3Agent: + """Load an agent from file.""" + raise NotImplementedError + + def save(self): + """Save the agent.""" + raise NotImplementedError + + def export(self): + """Export the agent to transportable file format.""" + raise NotImplementedError diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py new file mode 100644 index 00000000..5a6c9da5 --- /dev/null +++ b/src/primaite/agents/simple.py @@ -0,0 +1,56 @@ +from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum + + +class RandomAgent(HardCodedAgentSessionABC): + """ + A Random Agent. + + Get a completely random action from the action space. + """ + + def _calculate_action(self, obs): + return self._env.action_space.sample() + + +class DummyAgent(HardCodedAgentSessionABC): + """ + A Dummy Agent. + + All action spaces setup so dummy action is always 0 regardless of action + type used. + """ + + def _calculate_action(self, obs): + return 0 + + +class DoNothingACLAgent(HardCodedAgentSessionABC): + """ + A do nothing ACL agent. + + A valid ACL action that has no effect; does nothing. 
class DoNothingNodeAgent(HardCodedAgentSessionABC):
    """
    A do nothing Node agent.

    Always answers with a valid Node action that has no effect.
    """

    def _calculate_action(self, obs):
        """Return the environment key of the no-op node action; ``obs`` is ignored."""
        readable_noop = [1, "NONE", "ON", 0]
        encoded_noop = transform_action_node_enum(readable_noop)
        # The looked-up key should currently always be 0.
        return get_new_action(encoded_noop, self._env.action_dict)
""" action_node_property = NodePOLType(action[1]).name if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[ - 2 - ] <= 1: + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -29,6 +39,9 @@ def transform_action_acl_readable(action): example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -53,6 +66,9 @@ def is_valid_node_action(action): Does NOT consider: - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_node_readable(action) @@ -68,7 +84,10 @@ def is_valid_node_action(action): if node_property == "OPERATING" and node_action == "PATCHING": # Operating State cannot PATCH return False - if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + if node_property != "OPERATING" and node_action not in [ + "NONE", + "PATCHING", + ]: # Software States can only do Nothing or Patch return False return True @@ -83,6 +102,9 @@ def is_valid_acl_action(action): Does NOT consider: - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
def transform_change_obs_readable(obs):
    """
    Transform an observation into readable per-property lists.

    example:
    np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]

    :param obs: 2-D array indexed as ``obs[:, column]``; column 0 holds IDs,
        column 1 hardware states, column 2 OS states and every further column
        one service.
    :return: A list of lists: IDs, hardware state names, OS state names, then
        one list per service column.
    """

    def _service_value(raw):
        # Links bit/s don't have a service state, so values above 4 are kept as-is.
        return SoftwareState(raw).name if raw <= 4 else raw

    readable = [
        list(obs[:, 0]),
        [HardwareState(value).name for value in obs[:, 1]],
        [SoftwareState(value).name for value in obs[:, 2]],
    ]
    readable.extend([_service_value(raw) for raw in obs[:, column]] for column in range(3, obs.shape[1]))
    return readable
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
    """
    Convert to old observation.

    Links filled with 0's as no information is included in new observation space.

    example:
    obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])

    new_obs = array([[ 1, 1, 1, 1],
                     [ 2, 1, 1, 1],
                     [ 3, 1, 1, 1],
                     ...
                     [20, 0, 0, 0]])

    :param obs: Flat observation: ``num_nodes * (num_services + 2)`` node
        values followed by ``num_links`` link values.
    :param num_nodes: Number of node rows to rebuild.
    :param num_links: Number of trailing link values in ``obs``; may be 0.
    :param num_services: Number of service columns per node.
    :return: ``int64`` array of shape
        ``(num_nodes + num_links, num_services + 3)`` whose first column is a
        1-based ID.
    """
    obs = np.asarray(obs)
    # Split on an explicit index: slicing with ``-num_links`` breaks when
    # num_links is 0 (``obs[:-0]`` is empty and ``obs[-0:]`` is everything).
    split = len(obs) - num_links

    # Convert back to more readable, original format
    reshaped_nodes = obs[:split].reshape(num_nodes, num_services + 2)

    # Add empty links back and add node ID back
    new_obs = np.zeros(
        [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
        dtype=np.int64,
    )
    new_obs[:, 0] = np.arange(1, num_nodes + num_links + 1)  # Adding ID back
    new_obs[:num_nodes, 1:] = reshaped_nodes  # put values back in

    # Links are added to the last protocol/service slot but they are not
    # specific to that service
    new_obs[num_nodes:, -1] = obs[split:]

    return new_obs
def _describe_obs_change_helper(obs_change, is_link):
    """
    Helper function to describe what has changed in one observation row.

    Entries equal to -1 are treated as unchanged; index 0 is the ID and is
    never reported as a change. For a node row, every changed entry adds
    " <property> changed to <state>." to the description. For a link row,
    only the first changed entry is reported.

    :param obs_change: One row: the ID followed by new values, with -1 where
        nothing changed.
    :param is_link: Whether the row describes a link rather than a node.
    :return: The description string for the row.
    """

    def _property_label(index):
        # Indexes >= 3 are services, numbered from 0.
        if index < 3:
            return NodePOLType(index).name
        return NodePOLType(3).name + " " + str(index - 3)

    def _state_label(index):
        # Account for link statuses, hardware states and software states.
        raw = obs_change[index]
        if is_link:
            return LinkStatus(raw).name
        if index == 1:
            return HardwareState(raw).name
        return SoftwareState(raw).name

    # Indexes where a change has occurred, not including the 0th index
    changed = [index for index in range(1, len(obs_change)) if obs_change[index] != -1]
    property_labels = [_property_label(index) for index in changed]
    state_labels = [_state_label(index) for index in changed]

    if is_link:
        return f"ID {obs_change[0]}: Link traffic changed to {state_labels[0]}."

    pieces = [f"ID {obs_change[0]}:"]
    pieces.extend(
        " " + label + " changed to " + state + "." for label, state in zip(property_labels, state_labels)
    )
    return "".join(pieces)
def node_action_description(action):
    """
    Generate string describing a node-based action.

    :param action: A 4-item node action, either readable
        (``[1, 'SERVICE', 'PATCHING', 0]``) or enumerated (``[1, 3, 1, 0]``).
        It is treated as enumerated when item 1 is a Python or NumPy integer.
    :return: The description, or ``""`` when the action is "NONE" or the
        property is not OPERATING, OS or SERVICE.
    """
    # np.integer covers every NumPy integer width, not only int64.
    if isinstance(action[1], (int, np.integer)):
        # transform action to readable format
        action = transform_action_node_readable(action)

    node_id, node_property, property_action, service_id = action[0], action[1], action[2], action[3]

    if property_action == "NONE":
        return ""
    if node_property in ("OPERATING", "OS"):
        return f"NODE {node_id}, {node_property}, SET TO {property_action}"
    if node_property == "SERVICE":
        return f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
    return ""


def transform_action_acl_enum(action):
    """
    Convert acl action from readable str format, to enumerated format.

    :param action: Readable ACL action: decision ("NONE", "CREATE" or
        "DELETE"), permission ("DENY" or "ALLOW"), then four entries that are
        either "ANY" or an index.
    :return: NumPy array of the six enumerated values.
    :raises KeyError: If the decision or permission name is not recognised.
    """
    action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    action_permissions = {"DENY": 0, "ALLOW": 1}

    new_action = [action_decisions[action[0]], action_permissions[action[1]]]
    # For IPs, Ports and Protocols, ANY has value 0, otherwise it's just an index
    new_action.extend(0 if val == "ANY" else val for val in list(action[2:6]))

    return np.array(new_action)


def acl_action_description(action):
    """
    Generate string describing an acl-based action.

    :param action: A 6-item ACL action, readable or enumerated. It is treated
        as enumerated when item 0 is a Python or NumPy integer.
    :return: ``"NO ACL RULE APPLIED"`` for a "NONE" decision, otherwise a
        sentence describing the rule.
    """
    # np.integer covers every NumPy integer width, not only int64.
    if isinstance(action[0], (int, np.integer)):
        # transform action to readable format
        action = transform_action_acl_readable(action)
    if action[0] == "NONE":
        return "NO ACL RULE APPLIED"
    return (
        f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
        f" for protocol/service index {action[4]} on port index {action[5]}"
    )
def is_valid_node_action(action):
    """Is the node action an actual valid action.

    Only uses information about the action to determine if the action has an effect

    Does NOT consider:
        - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
        - Node already being in that state (turning an ON node ON)

    :param action: A node action in the enumerated form accepted by
        ``transform_action_node_readable``.
    :return: ``True`` when the property/action pair can have an effect,
        otherwise ``False``.
    """
    readable = transform_action_node_readable(action)
    node_property, node_action = readable[1], readable[2]

    # Acting on nothing, or doing nothing, never has an effect.
    if "NONE" in (node_property, node_action):
        return False

    targets_operating_state = node_property == "OPERATING"
    is_patching = node_action == "PATCHING"
    # Operating State cannot PATCH; Software States can only Patch.
    return targets_operating_state != is_patching
def get_new_action(old_action, action_dict):
    """
    Get new action (e.g. 32) from old action e.g. [1,1,1,0].

    Old_action can be either node or acl action type.

    :param old_action: Sequence describing the action in the old list form.
    :param action_dict: Mapping of new action key to its old list form.
    :return: The first key whose value equals ``old_action`` item by item,
        or 0 when there is none.
    """
    # Not all possible actions are included in dict, only valid actions are.
    # If the action is not in the dict, it's an invalid action so fall back to 0.
    matching_keys = (key for key, candidate in action_dict.items() if list(candidate) == list(old_action))
    return next(matching_keys, 0)
setup_app_dirs _LOGGER = getLogger(__name__) _LOGGER.info("Performing the PrimAITE first-time setup...") - if build_config: - _LOGGER.info("Building primaite_config.yaml...") + _LOGGER.info("Building primaite_config.yaml...") _LOGGER.info("Building the PrimAITE app directories...") setup_app_dirs.run() @@ -160,13 +151,54 @@ def setup(overwrite_existing: bool = True): @app.command() -def session(tc: str, ldc: str): +def session(tc: Optional[str] = None, ldc: Optional[str] = None): """ Run a PrimAITE session. - :param tc: The training config filepath. - :param ldc: The lay down config file path. + tc: The training config filepath. Optional. If no value is passed then + example default training config is used from: + ~/primaite/config/example_config/training/training_config_main.yaml. + + ldc: The lay down config file path. Optional. If no value is passed then + example default lay down config is used from: + ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. """ + from primaite.config.lay_down_config import dos_very_basic_config_path + from primaite.config.training_config import main_training_config_path from primaite.main import run + if not tc: + tc = main_training_config_path() + + if not ldc: + ldc = dos_very_basic_config_path() + run(training_config_path=tc, lay_down_config_path=ldc) + + +@app.command() +def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None): + """ + View or set the plotly template for Session plots. 
+ + To View, simply call: primaite plotly-template + + To set, call: primaite plotly-template + + For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK + """ + app_dirs = PlatformDirs(appname="primaite") + app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) + user_config_path = app_dirs.user_config_path / "primaite_config.yaml" + if user_config_path.exists(): + with open(user_config_path, "r") as file: + primaite_config = yaml.safe_load(file) + + if template: + primaite_config["session"]["outputs"]["plots"]["template"] = template.value + with open(user_config_path, "w") as file: + yaml.dump(primaite_config, file) + print(f"PrimAITE plotly template: {template.value}") + else: + template = primaite_config["session"]["outputs"]["plots"]["template"] + print(f"PrimAITE plotly template: {template}") diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 68ad80f2..db5d153c 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
"""Enumerations for APE.""" -from enum import Enum +from enum import Enum, IntEnum class NodeType(Enum): @@ -32,6 +32,7 @@ class Priority(Enum): class HardwareState(Enum): """Node hardware state enumeration.""" + NONE = 0 ON = 1 OFF = 2 RESETTING = 3 @@ -42,6 +43,7 @@ class HardwareState(Enum): class SoftwareState(Enum): """Software or Service state enumeration.""" + NONE = 0 GOOD = 1 PATCHING = 2 COMPROMISED = 3 @@ -79,6 +81,65 @@ class Protocol(Enum): NONE = 7 +class SessionType(Enum): + """The type of PrimAITE Session to be run.""" + + TRAIN = 1 + "Train an agent" + EVAL = 2 + "Evaluate an agent" + TRAIN_EVAL = 3 + "Train then evaluate an agent" + + +class AgentFramework(Enum): + """The agent algorithm framework/package.""" + + CUSTOM = 0 + "Custom Agent" + SB3 = 1 + "Stable Baselines3" + RLLIB = 2 + "Ray RLlib" + + +class DeepLearningFramework(Enum): + """The deep learning framework.""" + + TF = "tf" + "Tensorflow" + TF2 = "tf2" + "Tensorflow 2.x" + TORCH = "torch" + "PyTorch" + + +class AgentIdentifier(Enum): + """The Red Agent algo/class.""" + + A2C = 1 + "Advantage Actor Critic" + PPO = 2 + "Proximal Policy Optimization" + HARDCODED = 3 + "The Hardcoded agents" + DO_NOTHING = 4 + "The DoNothing agents" + RANDOM = 5 + "The RandomAgent" + DUMMY = 6 + "The DummyAgent" + + +class HardCodedAgentView(Enum): + """The view the deterministic hard-coded agent has of the environment.""" + + BASIC = 1 + "The current observation space only" + FULL = 2 + "Full environment view with actions taken and reward feedback" + + class ActionType(Enum): """Action type enumeration.""" @@ -128,3 +189,11 @@ class LinkStatus(Enum): MEDIUM = 2 HIGH = 3 OVERLOAD = 4 + + +class SB3OutputVerboseLevel(IntEnum): + """The Stable Baselines3 learn/eval output verbosity level.""" + + NONE = 0 + INFO = 1 + DEBUG = 2 diff --git a/src/primaite/common/training_config.py b/src/primaite/common/training_config.py deleted file mode 100644 index d45bedf9..00000000 --- 
a/src/primaite/common/training_config.py +++ /dev/null @@ -1,95 +0,0 @@ -# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -# """The config class.""" -# from dataclasses import dataclass -# -# from primaite.common.enums import ActionType -# -# -# @dataclass() -# class TrainingConfig: -# """Class to hold main config values.""" -# -# # Generic -# agent_identifier: str # The Red Agent algo/class to be used -# action_type: ActionType # type of action to use (NODE/ACL/ANY) -# num_episodes: int # number of episodes to train over -# num_steps: int # number of steps in an episode -# time_delay: int # delay between steps (ms) - applies to generic agents only -# # file -# session_type: str # the session type to run (TRAINING or EVALUATION) -# load_agent: str # Determine whether to load an agent from file -# agent_load_file: str # File path and file name of agent if you're loading one in -# -# # Environment -# observation_space_high_value: int # The high value for the observation space -# -# # Reward values -# # Generic -# all_ok: int -# # Node Hardware State -# off_should_be_on: int -# off_should_be_resetting: int -# on_should_be_off: int -# on_should_be_resetting: int -# resetting_should_be_on: int -# resetting_should_be_off: int -# resetting: int -# # Node Software or Service State -# good_should_be_patching: int -# good_should_be_compromised: int -# good_should_be_overwhelmed: int -# patching_should_be_good: int -# patching_should_be_compromised: int -# patching_should_be_overwhelmed: int -# patching: int -# compromised_should_be_good: int -# compromised_should_be_patching: int -# compromised_should_be_overwhelmed: int -# compromised: int -# overwhelmed_should_be_good: int -# overwhelmed_should_be_patching: int -# overwhelmed_should_be_compromised: int -# overwhelmed: int -# # Node File System State -# good_should_be_repairing: int -# good_should_be_restoring: int -# good_should_be_corrupt: int -# good_should_be_destroyed: int -# 
repairing_should_be_good: int -# repairing_should_be_restoring: int -# repairing_should_be_corrupt: int -# repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore -# -# repairing: int -# restoring_should_be_good: int -# restoring_should_be_repairing: int -# restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption) -# -# restoring_should_be_destroyed: int -# restoring: int -# corrupt_should_be_good: int -# corrupt_should_be_repairing: int -# corrupt_should_be_restoring: int -# corrupt_should_be_destroyed: int -# corrupt: int -# destroyed_should_be_good: int -# destroyed_should_be_repairing: int -# destroyed_should_be_restoring: int -# destroyed_should_be_corrupt: int -# destroyed: int -# scanning: int -# # IER status -# red_ier_running: int -# green_ier_blocked: int -# -# # Patching / Reset -# os_patching_duration: int # The time taken to patch the OS -# node_reset_duration: int # The time taken to reset a node (hardware) -# node_booting_duration = 0 # The Time taken to turn on the node -# node_shutdown_duration = 0 # The time taken to turn off the node -# service_patching_duration: int # The time taken to patch a service -# file_system_repairing_limit: int # The time take to repair a file -# file_system_restoring_limit: int # The time take to restore a file -# file_system_scanning_limit: int # The time taken to scan the file system -# # Patching / Reset -# diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml index f7c1e372..3f0c546a 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 128 - item_type: PORTS ports_list: - port: '80' diff --git 
a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml index e4a3385d..39bf7dac 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 128 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml index 9f37a6f0..619a0d35 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 256 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index b4bfa75e..a638fe14 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -1,16 +1,42 @@ -# Main Config File +# Training Config File -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 -# RED AGENT IDENTIFIER -# RANDOM or NONE +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). 
+# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False random_red_agent: False +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -25,21 +51,34 @@ observation_space: # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session num_episodes: 10 + # Number of time_steps per episode num_steps: 256 -# Time delay between steps (for generic agents) -time_delay: 10 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. 
Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 3e0a3e2f..96243daf 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -7,8 +7,10 @@ # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C -# RED AGENT IDENTIFIER -# RANDOM or NONE +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False random_red_agent: True # Sets How the Action Space is defined: diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 46389297..08f77b2f 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,14 +1,57 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from pathlib import Path -from typing import Final +from typing import Any, Dict, Final, Union -from primaite import USERS_CONFIG_DIR, getLogger +import yaml + +from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert a legacy lay down config dict to the new format. 
+ + :param legacy_config_dict: A legacy lay down config dict. + """ + _LOGGER.warning("Legacy lay down config conversion not yet implemented") + return legacy_config_dict + + +def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: + """ + Read in a lay down config yaml file. + + :param file_path: The config file path. + :param legacy_file: True if the config file is legacy format, otherwise + False. + :return: The lay down config as a dict. + :raises ValueError: If the file_path does not exist. + """ + if not isinstance(file_path, Path): + file_path = Path(file_path) + if file_path.exists(): + with open(file_path, "r") as file: + config = yaml.safe_load(file) + _LOGGER.debug(f"Loading lay down config file: {file_path}") + if legacy_file: + try: + config = convert_legacy_lay_down_config_dict(config) + except KeyError: + msg = ( + f"Failed to convert lay down config file {file_path} " + f"from legacy format. Attempting to use file as is." + ) + _LOGGER.error(msg) + return config + msg = f"Cannot load the lay down config as it does not exist: {file_path}" + _LOGGER.error(msg) + raise ValueError(msg) + + def ddos_basic_one_config_path() -> Path: """ The path to the example lay_down_config_1_DDOS_basic.yaml file. diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 5ec4d942..99664f02 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,58 +1,96 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
+from __future__ import annotations + from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Final, Optional, Union import yaml -from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import ActionType +from primaite import getLogger, USERS_CONFIG_DIR +from primaite.common.enums import ( + ActionType, + AgentFramework, + AgentIdentifier, + DeepLearningFramework, + HardCodedAgentView, + SB3OutputVerboseLevel, + SessionType, +) _LOGGER = getLogger(__name__) _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" +def main_training_config_path() -> Path: + """ + The path to the example training_config_main.yaml file. + + :return: The file path. + """ + path = _EXAMPLE_TRAINING / "training_config_main.yaml" + if not path.exists(): + msg = "Example config not found. Please run 'primaite setup'" + _LOGGER.critical(msg) + raise FileNotFoundError(msg) + + return path + + @dataclass() class TrainingConfig: """The Training Config class.""" - # Generic - agent_identifier: str = "STABLE_BASELINES3_A2C" - "The Red Agent algo/class to be used." + agent_framework: AgentFramework = AgentFramework.SB3 + "The AgentFramework" + + deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF + "The DeepLearningFramework" + + agent_identifier: AgentIdentifier = AgentIdentifier.PPO + "The AgentIdentifier" + + hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL + "The view the deterministic hard-coded agent has of the environment" random_red_agent: bool = False "Creates Random Red Agent Attacks" action_type: ActionType = ActionType.ANY - "The ActionType to use." + "The ActionType to use" num_episodes: int = 10 - "The number of episodes to train over." + "The number of episodes to train over" num_steps: int = 256 - "The number of steps in an episode." 
- observation_space: dict = field( - default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]} - ) - "The observation space config dict." + "The number of steps in an episode" + + checkpoint_every_n_episodes: int = 5 + "The agent will save a checkpoint every n episodes" + + observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}) + "The observation space config dict" time_delay: int = 10 - "The delay between steps (ms). Applies to generic agents only." + "The delay between steps (ms). Applies to generic agents only" # file - session_type: str = "TRAINING" - "the session type to run (TRAINING or EVALUATION)" + session_type: SessionType = SessionType.TRAIN + "The type of PrimAITE session to run" load_agent: str = False - "Determine whether to load an agent from file." + "Determine whether to load an agent from file" agent_load_file: Optional[str] = None - "File path and file name of agent if you're loading one in." + "File path and file name of agent if you're loading one in" # Environment observation_space_high_value: int = 1000000000 - "The high value for the observation space." + "The high value for the observation space" + + sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE + "Stable Baselines3 learn/eval output verbosity level" # Reward values # Generic @@ -117,28 +155,51 @@ class TrainingConfig: # Patching / Reset durations os_patching_duration: int = 5 - "The time taken to patch the OS." + "The time taken to patch the OS" node_reset_duration: int = 5 - "The time taken to reset a node (hardware)." + "The time taken to reset a node (hardware)" node_booting_duration: int = 3 - "The Time taken to turn on the node." + "The Time taken to turn on the node" node_shutdown_duration: int = 2 - "The time taken to turn off the node." + "The time taken to turn off the node" service_patching_duration: int = 5 - "The time taken to patch a service." 
+ "The time taken to patch a service" file_system_repairing_limit: int = 5 - "The time take to repair the file system." + "The time take to repair the file system" file_system_restoring_limit: int = 5 - "The time take to restore the file system." + "The time take to restore the file system" file_system_scanning_limit: int = 5 - "The time taken to scan the file system." + "The time taken to scan the file system" + + @classmethod + def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: + """ + Create an instance of TrainingConfig from a dict. + + :param config_dict: The training config dict. + :return: The instance of TrainingConfig. + """ + field_enum_map = { + "agent_framework": AgentFramework, + "deep_learning_framework": DeepLearningFramework, + "agent_identifier": AgentIdentifier, + "action_type": ActionType, + "session_type": SessionType, + "sb3_output_verbose_level": SB3OutputVerboseLevel, + "hard_coded_agent_view": HardCodedAgentView, + } + + for key, value in field_enum_map.items(): + if key in config_dict: + config_dict[key] = value[config_dict[key]] + return TrainingConfig(**config_dict) deterministic: bool = False "If true, the training will be deterministic" @@ -156,24 +217,28 @@ class TrainingConfig: """ data = self.__dict__ if json_serializable: - data["action_type"] = self.action_type.value + data["agent_framework"] = self.agent_framework.name + data["deep_learning_framework"] = self.deep_learning_framework.name + data["agent_identifier"] = self.agent_identifier.name + data["action_type"] = self.action_type.name + data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name + data["session_type"] = self.session_type.name + data["hard_coded_agent_view"] = self.hard_coded_agent_view.name return data - -def main_training_config_path() -> Path: - """ - The path to the example training_config_main.yaml file. - - :return: The file path. 
- """ - path = _EXAMPLE_TRAINING / "training_config_main.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path + def __str__(self) -> str: + tc = f"{self.agent_framework}, " + if self.agent_framework is AgentFramework.RLLIB: + tc += f"{self.deep_learning_framework}, " + tc += f"{self.agent_identifier}, " + if self.agent_identifier is AgentIdentifier.HARDCODED: + tc += f"{self.hard_coded_agent_view}, " + tc += f"{self.action_type}, " + tc += f"observation_space={self.observation_space}, " + tc += f"{self.num_episodes} episodes @ " + tc += f"{self.num_steps} steps" + return tc def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: @@ -204,15 +269,10 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf f"from legacy format. Attempting to use file as is." ) _LOGGER.error(msg) - # Convert values to Enums - config["action_type"] = ActionType[config["action_type"]] try: - return TrainingConfig(**config) + return TrainingConfig.from_dict(config) except TypeError as e: - msg = ( - f"Error when creating an instance of {TrainingConfig} " - f"from the training config file {file_path}" - ) + msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}" _LOGGER.critical(msg, exc_info=True) raise e msg = f"Cannot load the training config as it does not exist: {file_path}" @@ -221,19 +281,35 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], + agent_framework: AgentFramework = AgentFramework.SB3, + agent_identifier: AgentIdentifier = AgentIdentifier.PPO, + action_type: ActionType = ActionType.ANY, + num_steps: int = 256, ) -> Dict[str, Any]: """ Convert a legacy training 
config dict to the new format. :param legacy_config_dict: A legacy training config dict. - :param num_steps: The number of steps to set as legacy training configs - don't have num_steps values. + :param agent_framework: The agent framework to use as legacy training + configs don't have agent_framework values. + :param agent_identifier: The red agent identifier to use as legacy + training configs don't have agent_identifier values. :param action_type: The action space type to set as legacy training configs don't have action_type values. + :param num_steps: The number of steps to set as legacy training configs + don't have num_steps values. :return: The converted training config dict. """ - config_dict = {"num_steps": num_steps, "action_type": action_type} + config_dict = { + "agent_framework": agent_framework.name, + "agent_identifier": agent_identifier.name, + "action_type": action_type.name, + "num_steps": num_steps, + "sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name, + } + session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} + legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]] for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: @@ -249,7 +325,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str: :return: The mapped key. 
""" key_mapping = { - "agentIdentifier": "agent_identifier", + "agentIdentifier": None, "numEpisodes": "num_episodes", "timeDelay": "time_delay", "configFilename": None, diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py new file mode 100644 index 00000000..a7cc3e8b --- /dev/null +++ b/src/primaite/data_viz/__init__.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class PlotlyTemplate(Enum): + """The built-in plotly templates.""" + + PLOTLY = "plotly" + PLOTLY_WHITE = "plotly_white" + PLOTLY_DARK = "plotly_dark" + GGPLOT2 = "ggplot2" + SEABORN = "seaborn" + SIMPLE_WHITE = "simple_white" + NONE = "none" diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py new file mode 100644 index 00000000..245b9774 --- /dev/null +++ b/src/primaite/data_viz/session_plots.py @@ -0,0 +1,73 @@ +from pathlib import Path +from typing import Dict, Optional, Union + +import plotly.graph_objects as go +import polars as pl +import yaml +from plotly.graph_objs import Figure + +from primaite import _PLATFORM_DIRS + + +def _get_plotly_config() -> Dict: + """Get the plotly config from primaite_config.yaml.""" + user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" + with open(user_config_path, "r") as file: + primaite_config = yaml.safe_load(file) + return primaite_config["session"]["outputs"]["plots"] + + +def plot_av_reward_per_episode( + av_reward_per_episode_csv: Union[str, Path], + title: Optional[str] = None, + subtitle: Optional[str] = None, +) -> Figure: + """ + Plot the average reward per episode from a csv session output. + + :param av_reward_per_episode_csv: The average reward per episode csv + file path. + :param title: The plot title. This is optional. + :param subtitle: The plot subtitle. This is optional. + :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``. + """ + df = pl.read_csv(av_reward_per_episode_csv) + + if title: + if subtitle: + title = f"{title}
{subtitle}" + else: + if subtitle: + title = subtitle + + config = _get_plotly_config() + layout = go.Layout( + autosize=config["size"]["auto_size"], + width=config["size"]["width"], + height=config["size"]["height"], + ) + # Create the line graph with a colored line + fig = go.Figure(layout=layout) + fig.update_layout(template=config["template"]) + fig.add_trace( + go.Scatter( + x=df["Episode"], + y=df["Average Reward"], + mode="lines", + name="Mean Reward per Episode", + ) + ) + + # Set the layout of the graph + fig.update_layout( + xaxis={ + "title": "Episode", + "type": "linear", + "rangeslider": {"visible": config["range_slider"]}, + }, + yaxis={"title": "Average Reward"}, + title=title, + showlegend=False, + ) + + return fig diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 81ddaaf5..23bc4a39 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union +from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union import numpy as np from gym import spaces @@ -63,7 +63,7 @@ class NodeLinkTable(AbstractObservationComponent): """ _FIXED_PARAMETERS: int = 4 - _MAX_VAL: int = 1_000_000 + _MAX_VAL: int = 1_000_000_000 _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): @@ -101,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation[item_index][1] = node.hardware_state.value if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): self.current_observation[item_index][2] = node.software_state.value - self.current_observation[item_index][ - 3 - ] = node.file_system_state_observed.value + self.current_observation[item_index][3] = node.file_system_state_observed.value else: self.current_observation[item_index][2] = 0 
self.current_observation[item_index][3] = 0 @@ -111,9 +109,7 @@ class NodeLinkTable(AbstractObservationComponent): if isinstance(node, ServiceNode): for service in self.env.services_list: if node.has_service(service): - self.current_observation[item_index][ - service_index - ] = node.get_service_state(service).value + self.current_observation[item_index][service_index] = node.get_service_state(service).value else: self.current_observation[item_index][service_index] = 0 service_index += 1 @@ -133,9 +129,7 @@ class NodeLinkTable(AbstractObservationComponent): protocol_list = link.get_protocol_list() protocol_index = 0 for protocol in protocol_list: - self.current_observation[item_index][ - protocol_index + 4 - ] = protocol.get_load() + self.current_observation[item_index][protocol_index + 4] = protocol.get_load() protocol_index += 1 item_index += 1 @@ -244,7 +238,12 @@ class NodeStatuses(AbstractObservationComponent): if node.has_service(service): service_states[i] = node.get_service_state(service).value obs.extend( - [hardware_state, software_state, file_system_state, *service_states] + [ + hardware_state, + software_state, + file_system_state, + *service_states, + ] ) self.current_observation[:] = obs @@ -267,9 +266,7 @@ class NodeStatuses(AbstractObservationComponent): for service in services: structure.append(f"node_{node_id}_service_{service}_state_NONE") for state in SoftwareState: - structure.append( - f"node_{node_id}_service_{service}_state_{state.name}" - ) + structure.append(f"node_{node_id}_service_{service}_state_{state.name}") return structure @@ -325,9 +322,7 @@ class LinkTrafficLevels(AbstractObservationComponent): self._entries_per_link = self.env.num_services # 1. Define the shape of your observation space component - shape = ( - [self._quantisation_levels] * self.env.num_links * self._entries_per_link - ) + shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link # 2. 
Create Observation space self.space = spaces.MultiDiscrete(shape) @@ -356,9 +351,7 @@ class LinkTrafficLevels(AbstractObservationComponent): elif load >= bandwidth: traffic_level = self._quantisation_levels - 1 else: - traffic_level = (load / bandwidth) // ( - 1 / (self._quantisation_levels - 2) - ) + 1 + traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1 obs.append(int(traffic_level)) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index ce092cbd..dee80717 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,13 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy -import csv import logging import uuid as uuid -from datetime import datetime from pathlib import Path from random import choice, randint, sample, uniform -from typing import Dict, Tuple, Union +from typing import Dict, Final, Tuple, Union import networkx as nx import numpy as np @@ -15,11 +13,13 @@ import yaml from gym import Env, spaces from matplotlib import pyplot as plt +from primaite import getLogger from primaite.acl.access_control_list import AccessControlList from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, + AgentFramework, FileSystemState, HardwareState, NodePOLInitiator, @@ -27,6 +27,7 @@ from primaite.common.enums import ( NodeType, ObservationType, Priority, + SessionType, SoftwareState, ) from primaite.common.service import Service @@ -45,9 +46,9 @@ from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction +from 
primaite.utils.session_output_writer import SessionOutputWriter -_LOGGER = logging.getLogger(__name__) -_LOGGER.setLevel(logging.INFO) +_LOGGER = getLogger(__name__) class Primaite(Env): @@ -63,7 +64,6 @@ class Primaite(Env): self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], - transaction_list, session_path: Path, timestamp_str: str, ): @@ -72,26 +72,23 @@ class Primaite(Env): :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. - :param transaction_list: The list of transactions to populate. :param session_path: The directory path the session is writing to. :param timestamp_str: The session timestamp in the format: _. """ + self.session_path: Final[Path] = session_path + self.timestamp_str: Final[str] = timestamp_str self._training_config_path = training_config_path self._lay_down_config_path = lay_down_config_path - self.training_config: TrainingConfig = training_config.load( - training_config_path - ) + self.training_config: TrainingConfig = training_config.load(training_config_path) + _LOGGER.info(f"Using: {str(self.training_config)}") # Number of steps in an episode self.episode_steps = self.training_config.num_steps super(Primaite, self).__init__() - # Transaction list - self.transaction_list = transaction_list - # The agent in use self.agent_identifier = self.training_config.agent_identifier @@ -178,6 +175,9 @@ class Primaite(Env): # It will be initialised later. 
self.obs_handler: ObservationsHandler + self._obs_space_description = None + "The env observation space description for transactions writing" + # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -204,27 +204,23 @@ class Primaite(Env): plt.savefig(file_path, format="PNG") plt.clf() except Exception: - _LOGGER.error("Could not save network diagram") - _LOGGER.error("Exception occured", exc_info=True) - print("Could not save network diagram") + _LOGGER.error("Could not save network diagram", exc_info=True) # Initiate observation space self.observation_space, self.env_obs = self.init_observations() # Define Action Space - depends on action space type (Node or ACL) if self.training_config.action_type == ActionType.NODE: - _LOGGER.info("Action space type NODE selected") + _LOGGER.debug("Action space type NODE selected") # Terms (for node action space): # [0, num nodes] - node ID (0 = nothing, node ID) # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() - self.action_space = spaces.Discrete( - len(self.action_dict), seed=self.training_config.seed - ) + self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed) elif self.training_config.action_type == ActionType.ACL: - _LOGGER.info("Action space type ACL selected") + _LOGGER.debug("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) # [0, 1] - Permission (0 = DENY, 1 = ALLOW) @@ -233,33 +229,31 @@ class Primaite(Env): # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, 
num ports] - Port (0 = any, then 1 -> x resolving to port) self.action_dict = self.create_acl_action_dict() - self.action_space = spaces.Discrete( - len(self.action_dict), seed=self.training_config.seed - ) + self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed) elif self.training_config.action_type == ActionType.ANY: - _LOGGER.info("Action space type ANY selected - Node + ACL") + _LOGGER.debug("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() - self.action_space = spaces.Discrete( - len(self.action_dict), seed=self.training_config.seed - ) + self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed) else: - _LOGGER.info( - f"Invalid action type selected: {self.training_config.action_type}" - ) - # Set up a csv to store the results of the training - try: - header = ["Episode", "Average Reward"] + _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}") - file_name = f"average_reward_per_episode_{timestamp_str}.csv" - file_path = session_path / file_name - self.csv_file = open(file_path, "w", encoding="UTF8", newline="") - self.csv_writer = csv.writer(self.csv_file) - self.csv_writer.writerow(header) - except Exception: - _LOGGER.error( - "Could not create csv file to hold average reward per episode" - ) - _LOGGER.error("Exception occured", exc_info=True) + self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True) + self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True) + + @property + def actual_episode_count(self) -> int: + """Shifts the episode_count by -1 for RLlib.""" + if self.training_config.agent_framework is AgentFramework.RLLIB: + return self.episode_count - 1 + return self.episode_count + + def set_as_eval(self): + """Set the writers to write to eval directories.""" + self.episode_av_reward_writer = 
SessionOutputWriter(self, transaction_writer=False, learning_session=False) + self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False) + self.episode_count = 0 + self.step_count = 0 + self.total_step_count = 0 def reset(self): """ @@ -268,12 +262,14 @@ class Primaite(Env): Returns: Environment observation space (reset) """ - csv_data = self.episode_count, self.average_reward - self.csv_writer.writerow(csv_data) + if self.actual_episode_count > 0: + csv_data = self.actual_episode_count, self.average_reward + self.episode_av_reward_writer.write(csv_data) self.episode_count += 1 - # Don't need to reset links, as they are cleared and recalculated every step + # Don't need to reset links, as they are cleared and recalculated every + # step # Clear the ACL self.init_acl() @@ -293,6 +289,7 @@ class Primaite(Env): # Update observations space and return self.update_environent_obs() + return self.env_obs def step(self, action): @@ -308,15 +305,10 @@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ - if self.step_count == 0: - print(f"Episode: {str(self.episode_count)}") - # TEMP done = False - self.step_count += 1 self.total_step_count += 1 - # print("Episode step: " + str(self.step_count)) # Need to clear traffic on all links first for link_key, link_value in self.links.items(): @@ -326,14 +318,15 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction( - datetime.now(), self.agent_identifier, self.episode_count, self.step_count - ) + transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) # Load the initial observation space into the transaction - transaction.set_obs_space(self.obs_handler._flat_observation) + transaction.obs_space = self.obs_handler._flat_observation + + # Set the transaction obs space description + transaction.obs_space_description 
= self._obs_space_description # Load the action space into the transaction - transaction.set_action_space(copy.deepcopy(action)) + transaction.action_space = copy.deepcopy(action) # 1. Implement Blue Action self.interpret_action_and_apply(action) @@ -377,9 +370,7 @@ class Primaite(Env): self.acl, self.step_count, ) - apply_red_agent_node_pol( - self.nodes, self.red_iers, self.red_node_pol, self.step_count - ) + apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count) # Take snapshots of nodes and links self.nodes_post_red = copy.deepcopy(self.nodes) self.links_post_red = copy.deepcopy(self.links) @@ -395,17 +386,17 @@ class Primaite(Env): self.step_count, self.training_config, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + _LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count - if self.training_config.session_type == "EVALUATION": + if self.training_config.session_type is SessionType.EVAL: # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - print(f" Average Reward: {str(self.average_reward)}") + _LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}") # Load the reward into the transaction - transaction.set_reward(reward) + transaction.reward = reward # 6. Output Verbose # self.output_link_status() @@ -413,19 +404,21 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() - # 8. 
Add the transaction to the list of transactions - self.transaction_list.append(copy.deepcopy(transaction)) + # Write transaction to file + if self.actual_episode_count > 0: + self.transaction_writer.write(transaction) # Return return self.env_obs, reward, done, self.step_info def close(self): - """Calls the __close__ method.""" - self.__close__() + """Override parent close and close writers.""" + # Close files if last episode/step + # if self.can_finish: + super().close() - def __close__(self): - """Override close function.""" - self.csv_file.close() + self.transaction_writer.close() + self.episode_av_reward_writer.close() def init_acl(self): """Initialise the Access Control List.""" @@ -434,14 +427,9 @@ class Primaite(Env): def output_link_status(self): """Output the link status of all links to the console.""" for link_key, link_value in self.links.items(): - print("Link ID: " + link_value.get_id()) + _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - print( - " Protocol: " - + protocol.get_name().name - + ", Load: " - + str(protocol.get_load()) - ) + print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): """ @@ -456,13 +444,9 @@ class Primaite(Env): self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 6 - ): # ACL actions in multidiscrete form have len 6 + elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 4 - ): # Node actions in multdiscrete (array) from have len 4 + elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: logging.error("Invalid action type found") @@ -528,9 +512,7 @@ class Primaite(Env): return elif 
property_action == 1: # Patch (valid action if it's good or compromised) - node.set_service_state( - self.services_list[service_index], SoftwareState.PATCHING - ) + node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING) else: # Node is not of Service Type return @@ -685,6 +667,9 @@ class Primaite(Env): """ self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + if not self._obs_space_description: + self._obs_space_description = self.obs_handler.describe_structure() + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): @@ -1235,11 +1220,7 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = { - k + len(acl_action_dict) - 1: v - for k, v in node_action_dict.items() - if k != 0 - } + new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} @@ -1254,9 +1235,7 @@ class Primaite(Env): # Decide how many nodes become compromised node_list = list(self.nodes.values()) computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] - max_num_nodes_compromised = len( - computers - ) # only computers can become compromised + max_num_nodes_compromised = len(computers) # only computers can become compromised # random select between 1 and max_num_nodes_compromised num_nodes_to_compromise = randint(1, max_num_nodes_compromised) @@ -1267,9 +1246,7 @@ class Primaite(Env): source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised - max_step_compromised = ( - self.episode_steps // 2 - ) # always compromise in first half of episode + max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode # Bandwidth for all links 
bandwidths = [i.get_bandwidth() for i in list(self.links.values())] @@ -1318,9 +1295,7 @@ class Primaite(Env): ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port - ier_mission_criticality = ( - 0 # Red IER will never be important to green agent success - ) + ier_mission_criticality = 0 # Red IER will never be important to green agent success # We choose a node to attack based on the first that applies: # a. Green IERs, select dest node of the red ier based on dest node of green IER # b. Attack a random server that doesn't have a DENY acl rule in default config diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 1a1a0770..19094a18 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -41,27 +41,19 @@ def calculate_reward_function( reference_node = reference_nodes[node_key] # Hardware State - reward_value += score_node_operating_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values) # Software State if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode): - reward_value += score_node_os_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values) # Service State if isinstance(final_node, ServiceNode): - reward_value += score_node_service_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values) # File System State if isinstance(final_node, ActiveNode): - reward_value += score_node_file_system( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values) # Go 
through each red IER - penalise if it is running for ier_key, ier_value in red_iers.items(): @@ -80,14 +72,9 @@ def calculate_reward_function( if step_count >= start_step and step_count <= stop_step: reference_blocked = not reference_ier.get_is_running() live_blocked = not ier_value.get_is_running() - ier_reward = ( - config_values.green_ier_blocked * ier_value.get_mission_criticality() - ) + ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality() if live_blocked and not reference_blocked: - _LOGGER.debug( - f"Applying reward of {ier_reward} because IER {ier_key} is blocked" - ) reward_value += ier_reward elif live_blocked and reference_blocked: _LOGGER.debug( diff --git a/src/primaite/main.py b/src/primaite/main.py index f315cd34..7b1d7ab3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,367 +1,29 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -""" -The main PrimAITE session runner module. - -TODO: This will eventually be refactored out into a proper Session class. -TODO: The passing about of session_dir and timestamp_str is temporary and - will be cleaned up once we move to a proper Session class. 
-""" +"""The main PrimAITE session runner module.""" import argparse -import json -import time -from datetime import datetime from pathlib import Path -from typing import Final, Union -from uuid import uuid4 +from typing import Union -import numpy as np -from stable_baselines3 import A2C, PPO -from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.ppo import MlpPolicy as PPOMlp - -from primaite import SESSIONS_DIR, getLogger -from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import write_transaction_to_file +from primaite import getLogger +from primaite.primaite_session import PrimaiteSession _LOGGER = getLogger(__name__) -def run_generic(env: Primaite, config_values: TrainingConfig): - """ - Run against a generic agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - """ - for episode in range(0, config_values.num_episodes): - env.reset() - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = env.action_space.sample() - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - env.close() - - -def run_stable_baselines3_ppo( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str +def run( + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], ): - """ - Run against a stable_baselines3 PPO agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. 
- :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if config_values.load_agent: - try: - agent = PPO.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = PPO( - PPOMlp, - env, - verbose=0, - n_steps=config_values.num_steps, - seed=env.training_config.seed, - ) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - - for episode in range(0, config_values.num_episodes): - obs = env.reset() - - for step in range(0, config_values.num_steps): - action, _states = agent.predict( - obs, deterministic=env.training_config.deterministic - ) - # convert to int if action is a numpy array - if isinstance(action, np.ndarray): - action = np.int64(action) - obs, rewards, done, info = env.step(action) - env.close() - - -def run_stable_baselines3_a2c( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 A2C agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. 
- param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if config_values.load_agent: - try: - agent = A2C.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = A2C( - "MlpPolicy", - env, - verbose=0, - n_steps=config_values.num_steps, - seed=env.training_config.seed, - ) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - for episode in range(0, config_values.num_episodes): - obs = env.reset() - - for step in range(0, config_values.num_steps): - action, _states = agent.predict( - obs, deterministic=env.training_config.deterministic - ) - # convert to int if action is a numpy array - if isinstance(action, np.ndarray): - action = np.int64(action) - obs, rewards, done, info = env.step(action) - - env.close() - - -def _write_session_metadata_file( - session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite -): - """ - Write the ``session_metadata.json`` file. - - Creates a ``session_metadata.json`` in the ``session_dir`` directory - and adds the following key/value pairs: - - - uuid: The UUID assigned to the session upon instantiation. - - start_datetime: The date & time the session started in iso format. - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. 
- - env: - - training_config: - - All training config items - - lay_down_config: - - All lay down config items - - """ - metadata_dict = { - "uuid": uuid, - "start_datetime": session_timestamp.isoformat(), - "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, - "env": { - "training_config": env.training_config.to_dict(json_serializable=True), - "lay_down_config": env.lay_down_config, - }, - } - filepath = session_dir / "session_metadata.json" - _LOGGER.debug(f"Writing Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - - -def _update_session_metadata_file(session_dir: Path, env: Primaite): - """ - Update the ``session_metadata.json`` file. - - Updates the `session_metadata.json`` in the ``session_dir`` directory - with the following key/value pairs: - - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - """ - with open(session_dir / "session_metadata.json", "r") as file: - metadata_dict = json.load(file) - - metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = env.episode_count - metadata_dict["total_time_steps"] = env.total_step_count - - filepath = session_dir / "session_metadata.json" - _LOGGER.debug(f"Updating Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - - -def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): - """ - Persist an agent. - - Only works for stable baselines3 agents at present. - - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if not isinstance(agent, OnPolicyAlgorithm): - msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." 
- _LOGGER.error(msg) - else: - filepath = session_path / f"agent_saved_{timestamp_str}" - agent.save(filepath) - _LOGGER.debug(f"Trained agent saved as: {filepath}") - - -def _get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. - - This is set in the format of: - ~/primaite/sessions//_. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_dir - session_path.mkdir(exist_ok=True, parents=True) - - return session_path - - -def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): """Run the PrimAITE Session. :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. """ - # Welcome message - print("Welcome to the Primary-level AI Training Environment (PrimAITE)") - uuid = str(uuid4()) - session_timestamp: Final[datetime] = datetime.now() - session_dir = _get_session_path(session_timestamp) - timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session = PrimaiteSession(training_config_path, lay_down_config_path) - print(f"The output directory for this session is: {session_dir}") - - # Create a list of transactions - # A transaction is an object holding the: - # - episode # - # - step # - # - initial observation space - # - action - # - reward - # - new observation space - transaction_list = [] - - # Create the Primaite environment - env = Primaite( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path, - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Writing Session Metadata file...") - - _write_session_metadata_file( - session_dir=session_dir, uuid=uuid, 
session_timestamp=session_timestamp, env=env - ) - - config_values = env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = env.episode_steps - - # Run environment against an agent - if config_values.agent_identifier == "GENERIC": - run_generic(env=env, config_values=config_values) - elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": - run_stable_baselines3_a2c( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Session finished") - _LOGGER.debug("Session finished") - - print("Saving transaction logs...") - write_transaction_to_file( - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - obs_space_description=env.obs_handler.describe_structure(), - ) - - print("Updating Session Metadata file...") - _update_session_metadata_file(session_dir=session_dir, env=env) - - print("Finished") - _LOGGER.debug("Finished") + session.setup() + session.learn() + session.evaluate() if __name__ == "__main__": @@ -370,11 +32,7 @@ if __name__ == "__main__": parser.add_argument("--ldc") args = parser.parse_args() if not args.tc: - _LOGGER.error( - "Please provide a training config file using the --tc " "argument" - ) + _LOGGER.error("Please provide a training config file using the --tc " "argument") if not args.ldc: - _LOGGER.error( - "Please provide a lay down config file using the --ldc " "argument" - ) + _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index 57fa4c68..07a0ea0a 100644 --- a/src/primaite/nodes/active_node.py 
+++ b/src/primaite/nodes/active_node.py @@ -3,13 +3,7 @@ import logging from typing import Final -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.config.training_config import TrainingConfig from primaite.nodes.node import Node @@ -44,9 +38,7 @@ class ActiveNode(Node): :param file_system_state: The node file system state :param config_values: The config values """ - super().__init__( - node_id, name, node_type, priority, hardware_state, config_values - ) + super().__init__(node_id, name, node_type, priority, hardware_state, config_values) self.ip_address: str = ip_address # Related to Software self._software_state: SoftwareState = software_state @@ -125,14 +117,10 @@ class ActiveNode(Node): self.file_system_state_actual = file_system_state if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = ( - self.config_values.file_system_repairing_limit - ) + self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_state_observed = FileSystemState.REPAIRING elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = ( - self.config_values.file_system_restoring_limit - ) + self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_state_observed = FileSystemState.RESTORING elif file_system_state == FileSystemState.GOOD: self.file_system_state_observed = FileSystemState.GOOD @@ -145,9 +133,7 @@ class ActiveNode(Node): f"Node.file_system_state.actual:{self.file_system_state_actual}" ) - def set_file_system_state_if_not_compromised( - self, file_system_state: FileSystemState - ): + def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState): """ Sets the file system state (actual and observed) if not in a compromised state. 
@@ -164,14 +150,10 @@ class ActiveNode(Node): self.file_system_state_actual = file_system_state if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = ( - self.config_values.file_system_repairing_limit - ) + self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_state_observed = FileSystemState.REPAIRING elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = ( - self.config_values.file_system_restoring_limit - ) + self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_state_observed = FileSystemState.RESTORING elif file_system_state == FileSystemState.GOOD: self.file_system_state_observed = FileSystemState.GOOD diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index 6515097a..9aa5c7d7 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -28,9 +28,7 @@ class PassiveNode(Node): :param config_values: Config values. 
""" # Pass through to Super for now - super().__init__( - node_id, name, node_type, priority, hardware_state, config_values - ) + super().__init__(node_id, name, node_type, priority, hardware_state, config_values) @property def ip_address(self) -> str: diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 324592c3..5d69df92 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -3,13 +3,7 @@ import logging from typing import Dict, Final -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -128,9 +122,7 @@ class ServiceNode(ActiveNode): ) or software_state != SoftwareState.COMPROMISED: service_value.software_state = software_state if software_state == SoftwareState.PATCHING: - service_value.patching_count = ( - self.config_values.service_patching_duration - ) + service_value.patching_count = self.config_values.service_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so the state of a service " @@ -141,9 +133,7 @@ class ServiceNode(ActiveNode): f"Node.services[].software_state:{software_state}" ) - def set_service_state_if_not_compromised( - self, protocol_name: str, software_state: SoftwareState - ): + def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState): """ Sets the software_state of a service (protocol) on the node. 
@@ -159,9 +149,7 @@ class ServiceNode(ActiveNode): if service_value.software_state != SoftwareState.COMPROMISED: service_value.software_state = software_state if software_state == SoftwareState.PATCHING: - service_value.patching_count = ( - self.config_values.service_patching_duration - ) + service_value.patching_count = self.config_values.service_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so the state of a service " diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 71ed343e..0e81e581 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -4,7 +4,7 @@ import os import subprocess import sys -from primaite import NOTEBOOKS_DIR, getLogger +from primaite import getLogger, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 1d05dc3f..e9dfef8c 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -86,9 +86,7 @@ def apply_iers( and source_node.software_state != SoftwareState.PATCHING ): if source_node.has_service(protocol): - if source_node.service_running( - protocol - ) and not source_node.service_is_overwhelmed(protocol): + if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol): source_valid = True else: source_valid = False @@ -103,10 +101,7 @@ def apply_iers( # 2. 
Check the dest node situation if dest_node.node_type == NodeType.SWITCH: # It's a switch - if ( - dest_node.hardware_state == HardwareState.ON - and dest_node.software_state != SoftwareState.PATCHING - ): + if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: dest_valid = True else: # IER no longer valid @@ -116,14 +111,9 @@ def apply_iers( pass else: # It's not a switch or an actuator (so active node) - if ( - dest_node.hardware_state == HardwareState.ON - and dest_node.software_state != SoftwareState.PATCHING - ): + if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: if dest_node.has_service(protocol): - if dest_node.service_running( - protocol - ) and not dest_node.service_is_overwhelmed(protocol): + if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol): dest_valid = True else: dest_valid = False @@ -136,9 +126,7 @@ def apply_iers( dest_valid = False # 3. 
Check that the ACL doesn't block it - acl_block = acl.is_blocked( - source_node.ip_address, dest_node.ip_address, protocol, port - ) + acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) if acl_block: if _VERBOSE: print( @@ -169,10 +157,7 @@ def apply_iers( # We might have a switch in the path, so check all nodes are operational for node in path_node_list: - if ( - node.hardware_state != HardwareState.ON - or node.software_state == SoftwareState.PATCHING - ): + if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING: path_valid = False if path_valid: @@ -184,9 +169,7 @@ def apply_iers( # Check that the link capacity is not exceeded by the new load while count < path_node_list_length - 1: # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] - ) + edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth @@ -204,7 +187,8 @@ def apply_iers( while count < path_node_list_length - 1: # Get the link between the next two nodes edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] + path_node_list[count], + path_node_list[count + 1], ) link_id = edge_dict[0].get("id") link = links[link_id] diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index b23992e7..bff19bf8 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( - HardwareState, - NodePOLInitiator, - NodePOLType, - NodeType, - SoftwareState, -) +from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, 
SoftwareState from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed @@ -83,10 +77,7 @@ def apply_red_agent_iers( if source_node.hardware_state == HardwareState.ON: if source_node.has_service(protocol): # Red agents IERs can only be valid if the source service is in a compromised state - if ( - source_node.get_service_state(protocol) - == SoftwareState.COMPROMISED - ): + if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED: source_valid = True else: source_valid = False @@ -124,9 +115,7 @@ def apply_red_agent_iers( dest_valid = False # 3. Check that the ACL doesn't block it - acl_block = acl.is_blocked( - source_node.ip_address, dest_node.ip_address, protocol, port - ) + acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) if acl_block: if _VERBOSE: print( @@ -170,9 +159,7 @@ def apply_red_agent_iers( # Check that the link capacity is not exceeded by the new load while count < path_node_list_length - 1: # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] - ) + edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth @@ -190,7 +177,8 @@ def apply_red_agent_iers( while count < path_node_list_length - 1: # Get the link between the next two nodes edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] + path_node_list[count], + path_node_list[count + 1], ) link_id = edge_dict[0].get("id") link = links[link_id] @@ -248,9 +236,7 @@ def apply_red_agent_node_pol( state = node_instruction.get_state() source_node_id = node_instruction.get_source_node_id() source_node_service_name = node_instruction.get_source_node_service() - source_node_service_state_value = ( - 
node_instruction.get_source_node_service_state() - ) + source_node_service_state_value = node_instruction.get_source_node_service_state() passed_checks = False @@ -292,9 +278,7 @@ def apply_red_agent_node_pol( target_node.hardware_state = state elif pol_type == NodePOLType.OS: # Change OS state - if isinstance(target_node, ActiveNode) or isinstance( - target_node, ServiceNode - ): + if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.software_state = state elif pol_type == NodePOLType.SERVICE: # Change a service state @@ -302,9 +286,7 @@ def apply_red_agent_node_pol( target_node.set_service_state(service_name, state) else: # Change the file system status - if isinstance(target_node, ActiveNode) or isinstance( - target_node, ServiceNode - ): + if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.set_file_system_state(state) else: if _VERBOSE: diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py new file mode 100644 index 00000000..df3ebec1 --- /dev/null +++ b/src/primaite/primaite_session.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Dict, Final, Union + +from primaite import getLogger +from primaite.agents.agent import AgentSessionABC +from primaite.agents.hardcoded_acl import HardCodedACLAgent +from primaite.agents.hardcoded_node import HardCodedNodeAgent +from primaite.agents.rllib import RLlibAgent +from primaite.agents.sb3 import SB3Agent +from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent +from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType +from primaite.config import lay_down_config, training_config +from primaite.config.training_config import TrainingConfig + +_LOGGER = getLogger(__name__) + + +class PrimaiteSession: + """ + The PrimaiteSession class. 
+ + Provides a single learning and evaluation entry point for all training + and lay down configurations. + """ + + def __init__( + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + ): + """ + The PrimaiteSession constructor. + + :param training_config_path: The training config path. + :param lay_down_config_path: The lay down config path. + """ + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Final[Union[Path]] = training_config_path + self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + + self._agent_session: AgentSessionABC = None # noqa + self.session_path: Path = None # noqa + self.timestamp_str: str = None # noqa + self.learning_path: Path = None # noqa + self.evaluation_path: Path = None # noqa + + def setup(self): + """Performs the session setup.""" + if self._training_config.agent_framework == AgentFramework.CUSTOM: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}") + if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") + if self._training_config.action_type == ActionType.NODE: + # Deterministic Hardcoded Agent with Node Action Space + self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path) + + elif self._training_config.action_type == 
ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + raise NotImplementedError + + else: + # Invalid AgentIdentifier ActionType combo + raise ValueError + + elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}") + if self._training_config.action_type == ActionType.NODE: + self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) + + elif self._training_config.action_type == ActionType.ACL: + # Do Nothing Agent with ACL Action Space + self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path) + + elif self._training_config.action_type == ActionType.ANY: + # Do Nothing Agent with ANY Action Space + raise NotImplementedError + + else: + # Invalid AgentIdentifier ActionType combo + raise ValueError + + elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") + self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path) + elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") + self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path) + + else: + # Invalid AgentFramework AgentIdentifier combo + raise ValueError + + elif self._training_config.agent_framework == AgentFramework.SB3: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") + # Stable Baselines3 Agent + self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path) + + elif self._training_config.agent_framework == AgentFramework.RLLIB: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") + # Ray RLlib Agent + self._agent_session = 
RLlibAgent(self._training_config_path, self._lay_down_config_path) + + else: + # Invalid AgentFramework + raise ValueError + + self.session_path: Path = self._agent_session.session_path + self.timestamp_str: str = self._agent_session.timestamp_str + self.learning_path: Path = self._agent_session.learning_path + self.evaluation_path: Path = self._agent_session.evaluation_path + + def learn( + self, + **kwargs, + ): + """ + Train the agent. + + :param kwargs: Any agent-framework specific key word args. + """ + if not self._training_config.session_type == SessionType.EVAL: + self._agent_session.learn(**kwargs) + + def evaluate( + self, + **kwargs, + ): + """ + Evaluate the agent. + + :param kwargs: Any agent-framework specific key word args. + """ + if not self._training_config.session_type == SessionType.TRAIN: + self._agent_session.evaluate(**kwargs) + + def close(self): + """Closes the agent.""" + self._agent_session.close() diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 690544fb..b9e0d73c 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -1,5 +1,22 @@ # The main PrimAITE application config file # Logging -log_level: INFO -logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' +logging: + log_level: INFO + logger_format: + DEBUG: '%(asctime)s: %(message)s' + INFO: '%(asctime)s: %(message)s' + WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + +# Session +session: + outputs: + plots: + size: + auto_size: false + width: 1500 + height: 900 + template: plotly_white + range_slider: false diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 5192c48f..7fa96783 100644 
--- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -6,7 +6,7 @@ from pathlib import Path import pkg_resources -from primaite import NOTEBOOKS_DIR, getLogger +from primaite import getLogger, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) @@ -18,9 +18,7 @@ def run(overwrite_existing: bool = True): :param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off. """ - notebooks_package_data_root = pkg_resources.resource_filename( - "primaite", "notebooks/_package_data" - ) + notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data") for subdir, dirs, files in os.walk(notebooks_package_data_root): for file in files: fp = os.path.join(subdir, file) @@ -30,9 +28,7 @@ def run(overwrite_existing: bool = True): copy_file = not target_fp.is_file() if overwrite_existing and not copy_file: - copy_file = (not filecmp.cmp(fp, target_fp)) and ( - ".ipynb_checkpoints" not in str(target_fp) - ) + copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp)) if copy_file: shutil.copy2(fp, target_fp) diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index f4166c6a..5d62298c 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -5,7 +5,7 @@ from pathlib import Path import pkg_resources -from primaite import USERS_CONFIG_DIR, getLogger +from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) @@ -17,9 +17,7 @@ def run(overwrite_existing=True): :param overwrite_existing: A bool to toggle replacing existing edited config on or off. 
""" - configs_package_data_root = pkg_resources.resource_filename( - "primaite", "config/_package_data" - ) + configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data") for subdir, dirs, files in os.walk(configs_package_data_root): for file in files: diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index 9f6e8a13..693b11c1 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -1,5 +1,5 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger +from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 39236217..7db2444a 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -1,48 +1,99 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """The Transaction class.""" +from datetime import datetime +from typing import List, Tuple + +from primaite.common.enums import AgentIdentifier class Transaction(object): """Transaction class.""" - def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number): + def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int): """ - Init. + Transaction constructor. 
- Args: - _timestamp: The time this object was created - _agent_identifier: An identifier for the agent in use - _episode_number: The episode number - _step_number: The step number + :param agent_identifier: An identifier for the agent in use + :param episode_number: The episode number + :param step_number: The step number """ - self.timestamp = _timestamp - self.agent_identifier = _agent_identifier - self.episode_number = _episode_number - self.step_number = _step_number + self.timestamp = datetime.now() + "The datetime of the transaction" + self.agent_identifier: AgentIdentifier = agent_identifier + "The agent identifier" + self.episode_number: int = episode_number + "The episode number" + self.step_number: int = step_number + "The step number" + self.obs_space = None + "The observation space (pre)" + self.obs_space_pre = None + "The observation space before any actions are taken" + self.obs_space_post = None + "The observation space after any actions are taken" + self.reward = None + "The reward value" + self.action_space = None + "The action space invoked by the agent" + self.obs_space_description = None + "The env observation space description" - def set_obs_space(self, _obs_space): + def as_csv_data(self) -> Tuple[List, List]: """ - Sets the observation space (pre). + Converts the Transaction to a csv data row and provides a header. - Args: - _obs_space_pre: The observation space before any actions are taken + :return: A tuple consisting of (header, data). """ - self.obs_space = _obs_space + if isinstance(self.action_space, int): + action_length = self.action_space + else: + action_length = self.action_space.size - def set_reward(self, _reward): - """ - Sets the reward. 
+ # Create the action space headers array + action_header = [] + for x in range(action_length): + action_header.append("AS_" + str(x)) - Args: - _reward: The reward value - """ - self.reward = _reward + # Open up a csv file + header = ["Timestamp", "Episode", "Step", "Reward"] + header = header + action_header + self.obs_space_description - def set_action_space(self, _action_space): - """ - Sets the action space. + row = [ + str(self.timestamp), + str(self.episode_number), + str(self.step_number), + str(self.reward), + ] + row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist() + return header, row - Args: - _action_space: The action space invoked by the agent - """ - self.action_space = _action_space + +def _turn_action_space_to_array(action_space) -> List[str]: + """ + Turns action space into a string array so it can be saved to csv. + + :param action_space: The action space + :return: The action space as an array of strings + """ + if isinstance(action_space, list): + return [str(i) for i in action_space] + else: + return [str(action_space)] + + +def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]: + """ + Turns observation space into a string array so it can be saved to csv. + + :param obs_space: The observation space + :param obs_assets: The number of assets (i.e. nodes or links) in the + observation space + :param obs_features: The number of features associated with the asset + :return: The observation space as an array of strings + """ + return_array = [] + for x in range(obs_assets): + for y in range(obs_features): + return_array.append(str(obs_space[x][y])) + + return return_array diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py deleted file mode 100644 index 4e364f0b..00000000 --- a/src/primaite/transactions/transactions_to_file.py +++ /dev/null @@ -1,91 +0,0 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
-"""Writes the Transaction log list out to file for evaluation to utilse.""" - -import csv -from pathlib import Path - -from primaite import getLogger - -_LOGGER = getLogger(__name__) - - -def turn_action_space_to_array(_action_space): - """ - Turns action space into a string array so it can be saved to csv. - - Args: - _action_space: The action space. - """ - if isinstance(_action_space, list): - return [str(i) for i in _action_space] - else: - return [str(_action_space)] - - -def write_transaction_to_file( - transaction_list, - session_path: Path, - timestamp_str: str, - obs_space_description: list, -): - """ - Writes transaction logs to file to support training evaluation. - - :param transaction_list: The list of transactions from all steps and all - episodes. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - # Get the first transaction and use it to determine the makeup of the - # observation space and action space - # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action - # space as "AS_1" - # This will be tied into the PrimAITE Use Case so that they make sense - template_transation = transaction_list[0] - action_length = template_transation.action_space.size - # obs_shape = template_transation.obs_space_post.shape - # obs_assets = template_transation.obs_space_post.shape[0] - # if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - # obs_features = 1 - # else: - # obs_features = template_transation.obs_space_post.shape[1] - - # Create the action space headers array - action_header = [] - for x in range(action_length): - action_header.append("AS_" + str(x)) - - # Create the observation space headers array - # obs_header_initial = [f"pre_{o}" for o in obs_space_description] - # obs_header_new = [f"post_{o}" for o in obs_space_description] - - # Open up a csv file - header = ["Timestamp", "Episode", 
"Step", "Reward"] - header = header + action_header + obs_space_description - - try: - filename = session_path / f"all_transactions_{timestamp_str}.csv" - _LOGGER.debug(f"Saving transaction logs: {filename}") - csv_file = open(filename, "w", encoding="UTF8", newline="") - csv_writer = csv.writer(csv_file) - csv_writer.writerow(header) - - for transaction in transaction_list: - csv_data = [ - str(transaction.timestamp), - str(transaction.episode_number), - str(transaction.step_number), - str(transaction.reward), - ] - csv_data = ( - csv_data - + turn_action_space_to_array(transaction.action_space) - + transaction.obs_space.tolist() - ) - csv_writer.writerow(csv_data) - - csv_file.close() - except Exception: - _LOGGER.error("Could not save the transaction file", exc_info=True) diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py new file mode 100644 index 00000000..d04f375e --- /dev/null +++ b/src/primaite/utils/session_output_reader.py @@ -0,0 +1,20 @@ +from pathlib import Path +from typing import Dict, Union + +# Using polars as it's faster than Pandas; it will speed things up when +# files get big! +import polars as pl + + +def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: + """ + Read an average rewards per episode csv file and return as a dict. + + The dictionary keys are the episode number, and the values are the mean + reward for that episode. + + :param av_rewards_csv_file: The average rewards per episode csv file path. + :return: The average rewards per episode csv as a dict. 
+ """ + d = pl.read_csv(av_rewards_csv_file).to_dict() + return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])} diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py new file mode 100644 index 00000000..a05b0453 --- /dev/null +++ b/src/primaite/utils/session_output_writer.py @@ -0,0 +1,83 @@ +import csv +from logging import Logger +from typing import Final, List, Tuple, TYPE_CHECKING, Union + +from primaite import getLogger +from primaite.transactions.transaction import Transaction + +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + +_LOGGER: Logger = getLogger(__name__) + + +class SessionOutputWriter: + """ + A session output writer class. + + Is used to write session outputs to csv file. + """ + + _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [ + "Episode", + "Average Reward", + ] + + def __init__( + self, + env: "Primaite", + transaction_writer: bool = False, + learning_session: bool = True, + ): + self._env = env + self.transaction_writer = transaction_writer + self.learning_session = learning_session + + if self.transaction_writer: + fn = f"all_transactions_{self._env.timestamp_str}.csv" + else: + fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv" + + if self.learning_session: + self._csv_file_path = self._env.session_path / "learning" / fn + else: + self._csv_file_path = self._env.session_path / "evaluation" / fn + + self._csv_file_path.parent.mkdir(exist_ok=True, parents=True) + + self._csv_file = None + self._csv_writer = None + + self._first_write: bool = True + + def _init_csv_writer(self): + self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="") + + self._csv_writer = csv.writer(self._csv_file) + + def __del__(self): + self.close() + + def close(self): + """Close the csv file.""" + if self._csv_file: + self._csv_file.close() + _LOGGER.debug(f"Finished writing file: {self._csv_file_path}") + + def write(self, data: 
Union[Tuple, Transaction]): + """ + Write a row of session data. + + :param data: The row of data to write. Can be a Tuple or an instance + of Transaction. + """ + if isinstance(data, Transaction): + header, data = data.as_csv_data() + else: + header = self._AV_REWARD_PER_EPISODE_HEADER + + if self._first_write: + self._init_csv_writer() + self._csv_writer.writerow(header) + self._first_write = False + self._csv_writer.writerow(data) diff --git a/tests/config/legacy/legacy_training_config.yaml b/tests/config/legacy_conversion/legacy_training_config.yaml similarity index 100% rename from tests/config/legacy/legacy_training_config.yaml rename to tests/config/legacy_conversion/legacy_training_config.yaml diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml similarity index 87% rename from tests/config/legacy/new_training_config.yaml rename to tests/config/legacy_conversion/new_training_config.yaml index becc1799..49e6a00b 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy_conversion/new_training_config.yaml @@ -1,11 +1,20 @@ # Main Config File # Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C + +# Sets which agent algorithm framework will be used: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray[RLlib]) +# "NONE" (Custom Agent) +agent_framework: SB3 + +# Sets which Red Agent algo/class will be used: +# "PPO" (Proximal Policy Optimization) +# "A2C" (Advantage Actor Critic) +# "HARDCODED" (Custom Agent) +# "RANDOM" (Random Action) +agent_identifier: PPO + # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +27,7 @@ num_steps: 256 # Time delay between steps (for generic agents) time_delay: 10 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from 
file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml index 67aaa9de..d26d7955 100644 --- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: A2C -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -28,7 +39,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml index 29a89b8d..aae740b6 100644 --- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets 
which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -24,7 +35,7 @@ observation_space: time_delay: 1 # Filename of the scenario / laydown -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml index 8f2d9a38..4066eace 100644 --- a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -25,7 +36,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml index e8bb49ea..08452dda 100644 --- a/tests/config/obs_tests/main_config_without_obs.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 5 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index 2e752bc9..7f1ced01 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -1,10 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: DUMMY -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +30,7 @@ num_steps: 15 time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 5c5db582..97d0ddaf 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 15 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index 967fdcce..067b9a6d 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 5 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml new file mode 100644 index 00000000..800fe808 --- /dev/null +++ b/tests/config/test_random_red_main_config.yaml @@ -0,0 +1,112 @@ +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: DUMMY + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: True + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 2 +# Number of time_steps per episode +num_steps: 15 +# Time delay between steps (for generic agents) +time_delay: 1 + +# Type of session to be run (TRAINING or EVALUATION) +session_type: EVAL +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State 
+good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/conftest.py b/tests/conftest.py index 20ad8b23..8365c8ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,46 +1,151 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
+import datetime +import shutil import tempfile import time from datetime import datetime from pathlib import Path -from typing import Final, Union +from typing import Dict, Union +from unittest.mock import patch import pandas as pd +import pytest +from primaite import getLogger +from primaite.common.enums import AgentIdentifier from primaite.environment.primaite_env import Primaite -from primaite.main import _get_session_path, _write_session_metadata_file +from primaite.primaite_session import PrimaiteSession +from primaite.utils.session_output_reader import av_rewards_dict +from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 +_LOGGER = getLogger(__name__) -def _get_temp_session_path(session_timestamp: datetime) -> Path: + +class TempPrimaiteSession(PrimaiteSession): + """ + A temporary PrimaiteSession class. + + Uses context manager for deletion of files upon exit. + """ + + def __init__( + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + ): + super().__init__(training_config_path, lay_down_config_path) + self.setup() + + def learn_av_reward_per_episode(self) -> Dict[int, float]: + """Get the learn av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.learning_path / csv_file) + + def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: + """Get the eval av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.evaluation_path / csv_file) + + @property + def env(self) -> Primaite: + """Direct access to the env for ease of testing.""" + return self._agent_session._env # noqa + + def __enter__(self): + return self + + def __exit__(self, type, value, tb): + shutil.rmtree(self.session_path) + shutil.rmtree(self.session_path.parent) + _LOGGER.debug(f"Deleted temp session 
directory: {self.session_path}") + + +@pytest.fixture +def temp_primaite_session(request): + """ + Provides a temporary PrimaiteSession instance. + + It's temporary as it uses a temporary directory as the session path. + + To use this fixture you need to: + + - parametrize your test function with: + + - "temp_primaite_session" + - [[path to training config, path to lay down config]] + - Include the temp_primaite_session fixture as a param in your test + function. + - use the temp_primaite_session as a context manager assigning is the + name 'session'. + + .. code:: python + + from primaite.config.lay_down_config import dos_very_basic_config_path + from primaite.config.training_config import main_training_config_path + @pytest.mark.parametrize( + "temp_primaite_session", + [ + [main_training_config_path(), dos_very_basic_config_path()] + ], + indirect=True + ) + def test_primaite_session(temp_primaite_session): + with temp_primaite_session as session: + # Learning outputs are saved in session.learning_path + session.learn() + + # Evaluation outputs are saved in session.evaluation_path + session.evaluate() + + # To ensure that all files are written, you must call .close() + session.close() + + # If you need to inspect any session outputs, it must be done + # inside the context manager + + # Now that we've exited the context manager, the + # session.session_path directory and its contents are deleted + """ + training_config_path = request.param[0] + lay_down_config_path = request.param[1] + with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck: + mck.session_timestamp = datetime.now() + + return TempPrimaiteSession(training_config_path, lay_down_config_path) + + +@pytest.fixture +def temp_session_path() -> Path: """ Get a temp directory session path the test session will output to. - :param session_timestamp: This is the datetime that the session started. :return: The session directory path. 
""" + session_timestamp = datetime.now() date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path def _get_primaite_env_from_config( - training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + temp_session_path, ): """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() - session_path = _get_temp_session_path(session_timestamp) + session_path = temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( training_config_path=training_config_path, lay_down_config_path=lay_down_config_path, - transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, ) @@ -49,7 +154,7 @@ def _get_primaite_env_from_config( # TOOD: This needs t be refactored to happen outside. Should be part of # a main Session class. - if env.training_config.agent_identifier == "GENERIC": + if env.training_config.agent_identifier is AgentIdentifier.RANDOM: run_generic(env, config_values) return env @@ -95,9 +200,7 @@ def compare_file_content(output_a_file_path: str, output_b_file_path: str): # both files have the same content return True # both files have different content - print( - f"{output_a_file_path} and {output_b_file_path} has different contents" - ) + print(f"{output_a_file_path} and {output_b_file_path} has different contents") return False @@ -116,32 +219,3 @@ def compare_transaction_file(output_a_file_path: str, output_b_file_path: str): # if the comparison is empty, both files are the same i.e. 
True return data_a.compare(data_b).empty - - -class TestSession: - """Class that contains session values.""" - - def __init__(self, training_config_path, laydown_config_path): - self.session_timestamp: Final[datetime] = datetime.now() - self.session_dir = _get_session_path(self.session_timestamp) - self.timestamp_str = self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - self.transaction_list = [] - - print(f"The output directory for this session is: {self.session_dir}") - - self.env = Primaite( - training_config_path=training_config_path, - lay_down_config_path=laydown_config_path, - transaction_list=self.transaction_list, - session_path=self.session_dir, - timestamp_str=self.timestamp_str, - ) - - print("Writing Session Metadata file...") - - _write_session_metadata_file( - session_dir=self.session_dir, - uuid="test", - session_timestamp=self.session_timestamp, - env=self.env, - ) diff --git a/tests/e2e_integration_tests/test_primaite_main.py b/tests/e2e_integration_tests/test_primaite_main.py deleted file mode 100644 index b457557a..00000000 --- a/tests/e2e_integration_tests/test_primaite_main.py +++ /dev/null @@ -1,8 +0,0 @@ -from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.config.training_config import main_training_config_path -from primaite.main import run - - -def test_primaite_main_e2e(): - """Tests the primaite.main.run function end-to-end.""" - run(main_training_config_path(), data_manipulation_config_path()) diff --git a/tests/e2e_integration_tests/test_session_repeatability.py b/tests/e2e_integration_tests/test_session_repeatability.py index a1a8f16a..91be68b2 100644 --- a/tests/e2e_integration_tests/test_session_repeatability.py +++ b/tests/e2e_integration_tests/test_session_repeatability.py @@ -1,404 +1,3 @@ -import time - -from primaite import getLogger -from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.main import ( - _get_session_path, - _update_session_metadata_file, 
- run_generic, - run_stable_baselines3_a2c, - run_stable_baselines3_ppo, -) -from primaite.transactions.transactions_to_file import write_transaction_to_file -from tests import TEST_CONFIG_ROOT -from tests.conftest import TestSession, compare_file_content, compare_transaction_file - -_LOGGER = getLogger(__name__) - - -def test_generic_same_results(): - """Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same.""" - print("") - print("=======================") - print("Generic test run") - print("=======================") - print("") - - # run session 1 - session1 = TestSession( - TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - config_values = session1.env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = session1.env.episode_steps - - run_generic(env=session1.env, config_values=session1.env.training_config) - - _update_session_metadata_file(session_dir=session1.session_dir, env=session1.env) - - # run session 2 - session2 = TestSession( - TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - config_values = session2.env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = session2.env.episode_steps - - run_generic(env=session2.env, config_values=session2.env.training_config) - - _update_session_metadata_file(session_dir=session2.session_dir, env=session2.env) - - # wait until the csv files have been closed - while (not session1.env.csv_file.closed) or (not session2.env.csv_file.closed): - time.sleep(1) - - # check if both outputs are the same - assert ( - compare_file_content( - session1.env.csv_file.name, - session2.env.csv_file.name, - ) - is True - ) - - # deterministic run - deterministic = TestSession( - TEST_CONFIG_ROOT / 
"e2e/generic_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - deterministic.env.training_config.deterministic = True - - run_generic(env=deterministic.env, config_values=deterministic.env.training_config) - - _update_session_metadata_file( - session_dir=deterministic.session_dir, env=deterministic.env - ) - - # check if both outputs are the same - assert ( - compare_file_content( - deterministic.env.csv_file.name, - TEST_CONFIG_ROOT - / "e2e/deterministic_test_outputs/deterministic_generic.csv", - ) - is True - ) - - -def test_ppo_same_results(): - """Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same.""" - print("") - print("=======================") - print("PPO test run") - print("=======================") - print("") - - training_session = TestSession( - TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - # Train agent - training_session.env.training_config.session_type = "TRAINING" - - config_values = training_session.env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = training_session.env.episode_steps - - run_stable_baselines3_ppo( - env=training_session.env, - config_values=config_values, - session_path=training_session.session_dir, - timestamp_str=training_session.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=training_session.transaction_list, - session_path=training_session.session_dir, - timestamp_str=training_session.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=training_session.session_dir, env=training_session.env - ) - - # Evaluate Agent again - eval_session1 = TestSession( - TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = 
eval_session1.env.episode_steps - eval_session1.env.training_config.session_type = "EVALUATE" - - # load the agent that was trained previously - eval_session1.env.training_config.load_agent = True - eval_session1.env.training_config.agent_load_file = ( - _get_session_path(training_session.session_timestamp) - / f"agent_saved_{training_session.timestamp_str}.zip" - ) - - config_values = eval_session1.env.training_config - - run_stable_baselines3_ppo( - env=eval_session1.env, - config_values=config_values, - session_path=eval_session1.session_dir, - timestamp_str=eval_session1.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=eval_session1.transaction_list, - session_path=eval_session1.session_dir, - timestamp_str=eval_session1.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=eval_session1.session_dir, env=eval_session1.env - ) - - eval_session2 = TestSession( - TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = eval_session2.env.episode_steps - eval_session2.env.training_config.session_type = "EVALUATE" - - # load the agent that was trained previously - eval_session2.env.training_config.load_agent = True - eval_session2.env.training_config.agent_load_file = ( - _get_session_path(training_session.session_timestamp) - / f"agent_saved_{training_session.timestamp_str}.zip" - ) - - config_values = eval_session2.env.training_config - - run_stable_baselines3_ppo( - env=eval_session2.env, - config_values=config_values, - session_path=eval_session2.session_dir, - timestamp_str=eval_session2.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=eval_session2.transaction_list, - session_path=eval_session2.session_dir, - timestamp_str=eval_session2.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=eval_session2.session_dir, env=eval_session2.env - ) 
- - # check if both eval outputs are the same - assert ( - compare_transaction_file( - eval_session1.session_dir - / f"all_transactions_{eval_session1.timestamp_str}.csv", - eval_session2.session_dir - / f"all_transactions_{eval_session2.timestamp_str}.csv", - ) - is True - ) - - # deterministic run - deterministic = TestSession( - TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - deterministic.env.training_config.deterministic = True - - run_stable_baselines3_ppo( - env=deterministic.env, - config_values=config_values, - session_path=deterministic.session_dir, - timestamp_str=deterministic.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=deterministic.transaction_list, - session_path=deterministic.session_dir, - timestamp_str=deterministic.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=deterministic.session_dir, env=deterministic.env - ) - - # check if both outputs are the same - assert ( - compare_transaction_file( - deterministic.session_dir - / f"all_transactions_{deterministic.timestamp_str}.csv", - TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv", - ) - is True - ) - - -def test_a2c_same_results(): - """Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same.""" - print("") - print("=======================") - print("A2C test run") - print("=======================") - print("") - - training_session = TestSession( - TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - # Train agent - training_session.env.training_config.session_type = "TRAINING" - - config_values = training_session.env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = training_session.env.episode_steps - - run_stable_baselines3_a2c( - env=training_session.env, - config_values=config_values, - 
session_path=training_session.session_dir, - timestamp_str=training_session.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=training_session.transaction_list, - session_path=training_session.session_dir, - timestamp_str=training_session.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=training_session.session_dir, env=training_session.env - ) - - # Evaluate Agent again - eval_session1 = TestSession( - TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = eval_session1.env.episode_steps - eval_session1.env.training_config.session_type = "EVALUATE" - - # load the agent that was trained previously - eval_session1.env.training_config.load_agent = True - eval_session1.env.training_config.agent_load_file = ( - _get_session_path(training_session.session_timestamp) - / f"agent_saved_{training_session.timestamp_str}.zip" - ) - - config_values = eval_session1.env.training_config - - run_stable_baselines3_a2c( - env=eval_session1.env, - config_values=config_values, - session_path=eval_session1.session_dir, - timestamp_str=eval_session1.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=eval_session1.transaction_list, - session_path=eval_session1.session_dir, - timestamp_str=eval_session1.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=eval_session1.session_dir, env=eval_session1.env - ) - - eval_session2 = TestSession( - TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = eval_session2.env.episode_steps - eval_session2.env.training_config.session_type = "EVALUATE" - - # load the agent that was trained previously - eval_session2.env.training_config.load_agent = True - 
eval_session2.env.training_config.agent_load_file = ( - _get_session_path(training_session.session_timestamp) - / f"agent_saved_{training_session.timestamp_str}.zip" - ) - - config_values = eval_session2.env.training_config - - run_stable_baselines3_a2c( - env=eval_session2.env, - config_values=config_values, - session_path=eval_session2.session_dir, - timestamp_str=eval_session2.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=eval_session2.transaction_list, - session_path=eval_session2.session_dir, - timestamp_str=eval_session2.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=eval_session2.session_dir, env=eval_session2.env - ) - - # check if both eval outputs are the same - assert ( - compare_transaction_file( - eval_session1.session_dir - / f"all_transactions_{eval_session1.timestamp_str}.csv", - eval_session2.session_dir - / f"all_transactions_{eval_session2.timestamp_str}.csv", - ) - is True - ) - - # deterministic run - deterministic = TestSession( - TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path(), - ) - - deterministic.env.training_config.deterministic = True - - run_stable_baselines3_a2c( - env=deterministic.env, - config_values=config_values, - session_path=deterministic.session_dir, - timestamp_str=deterministic.timestamp_str, - ) - - write_transaction_to_file( - transaction_list=deterministic.transaction_list, - session_path=deterministic.session_dir, - timestamp_str=deterministic.timestamp_str, - ) - - _update_session_metadata_file( - session_dir=deterministic.session_dir, env=deterministic.env - ) - - # check if both outputs are the same - assert ( - compare_transaction_file( - deterministic.session_dir - / f"all_transactions_{deterministic.timestamp_str}.csv", - TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv", - ) - is True - ) +def temp(): + """TODO rewrite tests.""" + pass diff --git a/tests/e2e_integration_tests/__init__.py 
b/tests/mock_and_patch/__init__.py similarity index 100% rename from tests/e2e_integration_tests/__init__.py rename to tests/mock_and_patch/__init__.py diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py new file mode 100644 index 00000000..feff52f6 --- /dev/null +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -0,0 +1,22 @@ +import tempfile +from datetime import datetime +from pathlib import Path + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def get_temp_session_path(session_timestamp: datetime) -> Path: + """ + Get a temp directory session path the test session will output to. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. + """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + session_path.mkdir(exist_ok=True, parents=True) + _LOGGER.debug(f"Created temp session directory: {session_path}") + return session_path diff --git a/tests/test_acl.py b/tests/test_acl.py index 260ccffc..30f12697 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -95,8 +95,6 @@ def test_rule_hash(): rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") hash_value_local = hash(rule) - hash_value_remote = acl.get_dictionary_hash( - "DENY", "192.168.1.1", "192.168.1.2", "TCP", "80" - ) + hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") assert hash_value_local == hash_value_remote diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index efca7b0b..d1082049 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -2,84 +2,79 @@ import numpy as np import pytest -from primaite.environment.observations import ( - NodeLinkTable, - NodeStatuses, - ObservationsHandler, -) -from 
primaite.environment.primaite_env import Primaite +from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config -@pytest.fixture -def env(request): - """Build Primaite environment for integration tests of observation space.""" - marker = request.node.get_closest_marker("env_config_paths") - training_config_path = marker.args[0]["training_config_path"] - lay_down_config_path = marker.args[0]["lay_down_config_path"] - env = _get_primaite_env_from_config( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path, - ) - yield env - - -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_without_obs.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) -def test_default_obs_space(env: Primaite): +def test_default_obs_space(temp_primaite_session): """Create environment with no obs space defined in config and check that the default obs space was created.""" - env.update_environent_obs() + with temp_primaite_session as session: + session.env.update_environent_obs() - components = env.obs_handler.registered_obs_components + components = session.env.obs_handler.registered_obs_components - assert len(components) == 1 - assert isinstance(components[0], NodeLinkTable) + assert len(components) == 1 + assert isinstance(components[0], NodeLinkTable) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_without_obs.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / 
"obs_tests/main_config_without_obs.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) -def test_registering_components(env: Primaite): +def test_registering_components(temp_primaite_session): """Test regitering and deregistering a component.""" - handler = ObservationsHandler() - component = NodeStatuses(env) - handler.register(component) - assert component in handler.registered_obs_components - handler.deregister(component) - assert component not in handler.registered_obs_components + with temp_primaite_session as session: + env = session.env + handler = ObservationsHandler() + component = NodeStatuses(env) + handler.register(component) + assert component in handler.registered_obs_components + handler.deregister(component) + assert component not in handler.registered_obs_components -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_NODE_LINK_TABLE.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestNodeLinkTable: """Test the NodeLinkTable observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with box observation space.""" - env.update_environent_obs() + with temp_primaite_session as session: + env = session.env + env.update_environent_obs() - # we have three nodes and two links, with two service - # therefore the box observation space will have: - # * 5 rows (3 nodes + 2 links) - # * 6 columns (four fixed and two for the services) - assert env.env_obs.shape == (5, 6) + # we have three nodes and two links, with two service + # therefore the box observation space will have: + # * 5 rows (3 nodes + 2 links) + # * 6 columns (four fixed and two for 
the services) + assert env.env_obs.shape == (5, 6) - def test_value(self, env: Primaite): + def test_value(self, temp_primaite_session): """Test that the observation is generated correctly. The laydown has: @@ -125,36 +120,43 @@ class TestNodeLinkTable: * 999 (999 traffic service1) * 0 (no traffic for service2) """ - # act = np.asarray([0,]) - obs, reward, done, info = env.step(0) # apply the 'do nothing' action + with temp_primaite_session as session: + env = session.env + # act = np.asarray([0,]) + obs, reward, done, info = env.step(0) # apply the 'do nothing' action - assert np.array_equal( - obs, - [ - [1, 1, 3, 1, 1, 1], - [2, 1, 1, 1, 1, 4], - [3, 1, 1, 1, 0, 0], - [4, 0, 0, 0, 999, 0], - [5, 0, 0, 0, 999, 0], - ], - ) + assert np.array_equal( + obs, + [ + [1, 1, 3, 1, 1, 1], + [2, 1, 1, 1, 1, 4], + [3, 1, 1, 1, 0, 0], + [4, 0, 0, 0, 999, 0], + [5, 0, 0, 0, 999, 0], + ], + ) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_NODE_STATUSES.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestNodeStatuses: """Test the NodeStatuses observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with NodeStatuses as the only component.""" - assert env.env_obs.shape == (15,) + with temp_primaite_session as session: + env = session.env + assert env.env_obs.shape == (15,) - def test_values(self, env: Primaite): + def test_values(self, temp_primaite_session): """Test that the hardware and software states are encoded correctly. 
The laydown has: @@ -181,28 +183,36 @@ class TestNodeStatuses: * service 1 = n/a (0) * service 2 = n/a (0) """ - obs, _, _, _ = env.step(0) # apply the 'do nothing' action - assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) + with temp_primaite_session as session: + env = session.env + obs, _, _, _ = env.step(0) # apply the 'do nothing' action + print(obs) + assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestLinkTrafficLevels: """Test the LinkTrafficLevels observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with MultiDiscrete observation space.""" - env.update_environent_obs() + with temp_primaite_session as session: + env = session.env + env.update_environent_obs() - # we have two links and two services, so the shape should be 2 * 2 - assert env.env_obs.shape == (2 * 2,) + # we have two links and two services, so the shape should be 2 * 2 + assert env.env_obs.shape == (2 * 2,) - def test_values(self, env: Primaite): + def test_values(self, temp_primaite_session): """Test that traffic values are encoded correctly. 
The laydown has: @@ -212,12 +222,14 @@ class TestLinkTrafficLevels: * an IER trying to send 999 bits of data over both links the whole time (via the first service) * link bandwidth of 1000, therefore the utilisation is 99.9% """ - obs, reward, done, info = env.step(0) - obs, reward, done, info = env.step(0) + with temp_primaite_session as session: + env = session.env + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) - # the observation space has combine_service_traffic set to False, so the space has this format: - # [link1_service1, link1_service2, link2_service1, link2_service2] - # we send 999 bits of data via link1 and link2 on service 1. - # therefore the first and third elements should be 6 and all others 0 - # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) - assert np.array_equal(obs, [6, 0, 6, 0]) + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 999 bits of data via link1 and link2 on service 1. 
+ # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%) + assert np.array_equal(obs, [6, 0, 6, 0]) diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py new file mode 100644 index 00000000..ae0b0870 --- /dev/null +++ b/tests/test_primaite_session.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from primaite import getLogger +from primaite.config.lay_down_config import dos_very_basic_config_path +from primaite.config.training_config import main_training_config_path + +_LOGGER = getLogger(__name__) + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[main_training_config_path(), dos_very_basic_config_path()]], + indirect=True, +) +def test_primaite_session(temp_primaite_session): + """Tests the PrimaiteSession class and its outputs.""" + with temp_primaite_session as session: + session_path = session.session_path + assert session_path.exists() + session.learn() + # Learning outputs are saved in session.learning_path + session.evaluate() + # Evaluation outputs are saved in session.evaluation_path + + # If you need to inspect any session outputs, it must be done inside + # the context manager + + # Check that the metadata json file exists + assert (session_path / "session_metadata.json").exists() + + # Check that the network png file exists + assert (session_path / f"network_{session.timestamp_str}.png").exists() + + # Check that both the transactions and av reward csv files exist + for file in session.learning_path.iterdir(): + if file.suffix == ".csv": + assert "all_transactions" in file.name or "average_reward_per_episode" in file.name + + # Check that both the transactions and av reward csv files exist + for file in session.evaluation_path.iterdir(): + if file.suffix == ".csv": + assert "all_transactions" in file.name or "average_reward_per_episode" in file.name + + _LOGGER.debug("Inspecting files in temp session path...") + for dir_path, 
dir_names, file_names in os.walk(session_path): + for file in file_names: + path = os.path.join(dir_path, file) + file_str = path.split(str(session_path))[-1] + _LOGGER.debug(f" {file_str}") + + # Now that we've exited the context manager, the session.session_path + # directory and its contents are deleted + assert not session_path.exists() diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 6b06dbb1..f8885f3e 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,68 +1,30 @@ -from datetime import datetime +import pytest from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.environment.primaite_env import Primaite from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_temp_session_path -def run_generic(env, config_values): - """Run against a generic agent.""" - # Reset the environment at the start of the episode - env.reset() - for episode in range(0, config_values.num_episodes): - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - # action = env.action_space.sample() - action = 0 - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Reset the environment at the end of the episode - env.reset() - - env.close() - - -def test_random_red_agent_behaviour(): - """ - Test that hardware state is penalised at each step. - - When the initial state is OFF compared to reference state which is ON. 
- """ +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "test_random_red_main_config.yaml", + data_manipulation_config_path(), + ] + ], + indirect=True, +) +def test_random_red_agent_behaviour(temp_primaite_session): + """Test that red agent POL is randomised each episode.""" list_of_node_instructions = [] - # RUN TWICE so we can make sure that red agent is randomised - for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" - session_timestamp: datetime = datetime.now() - session_path = _get_temp_session_path(session_timestamp) + with temp_primaite_session as session: + session.evaluate() + list_of_node_instructions.append(session.env.red_node_pol) - timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - env = Primaite( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=data_manipulation_config_path(), - transaction_list=[], - session_path=session_path, - timestamp_str=timestamp_str, - ) - # set red_agent_ - env.training_config.random_red_agent = True - training_config = env.training_config - training_config.num_steps = env.episode_steps - - run_generic(env, training_config) - # add red pol instructions to list - list_of_node_instructions.append(env.red_node_pol) + session.evaluate() + list_of_node_instructions.append(session.env.red_node_pol) # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): @@ -73,5 +35,4 @@ def test_random_red_agent_behaviour(): print(f"{key} end step: {instruction.get_end_step()}") print(f"{key} target node id: {instruction.get_target_node_id()}") print("") - assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index abe8115c..fb7dc83d 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,13 +1,7 @@ 
"""Used to test Active Node functions.""" import pytest -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -57,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state): file_system_state="GOOD", config_values=1, ) - service_attributes = Service( - name="node", port="80", software_state=SoftwareState.COMPROMISED - ) + service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED) service_node.add_service(service_attributes) for x in range(5): diff --git a/tests/test_reward.py b/tests/test_reward.py index b8c92274..81437860 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -1,21 +1,26 @@ +import pytest + from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config -def test_rewards_are_being_penalised_at_each_step_function(): +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_rewards_are_being_penalised_at_each_step_function( + temp_primaite_session, +): """ Test that hardware state is penalised at each step. When the initial state is OFF compared to reference state which is ON. 
- """ - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_lay_down_config.yaml", - ) - """ The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps: On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute. For example, turning the nodes' file system state to CORRUPT from its original state GOOD. @@ -38,4 +43,8 @@ def test_rewards_are_being_penalised_at_each_step_function(): For the 4 steps where this occurs the average reward is: Average Reward: -8 (-120 / 15) """ - assert env.average_reward == -8.0 + with temp_primaite_session as session: + session.evaluate() + session.close() + ev_rewards = session.eval_av_reward_per_episode_csv() + assert ev_rewards[1] == -8.0 diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 8ff43fe6..5d55b9c9 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,9 +1,10 @@ import time +import pytest + from primaite.common.enums import HardwareState from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config def run_generic_set_actions(env: Primaite): @@ -17,7 +18,6 @@ def run_generic_set_actions(env: Primaite): # TEMP - random action for now # action = env.blue_agent_action(obs) action = 0 - print("Episode:", episode, "\nStep:", step) if step == 5: # [1, 1, 2, 1, 1, 1] # Creates an ACL rule @@ -44,59 +44,71 @@ def run_generic_set_actions(env: Primaite): # env.close() -def test_single_action_space_is_valid(): - """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", - 
lay_down_config_path=TEST_CONFIG_ROOT - / "single_action_space_lay_down_config.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_single_action_space_is_valid(temp_primaite_session): + """Test single action space is valid.""" + with temp_primaite_session as session: + env = session.env - run_generic_set_actions(env) - - # Retrieve the action space dictionary values from environment - env_action_space_dict = env.action_dict.values() - # Flags to check the conditions of the action space - contains_acl_actions = False - contains_node_actions = False - both_action_spaces = False - # Loop through each element of the list (which is every value from the dictionary) - for dict_item in env_action_space_dict: - # Node action detected - if len(dict_item) == 4: - contains_node_actions = True - # Link action detected - elif len(dict_item) == 6: - contains_acl_actions = True - # If both are there then the ANY action type is working - if contains_node_actions and contains_acl_actions: - both_action_spaces = True - # Check condition should be True - assert both_action_spaces + run_generic_set_actions(env) + # Retrieve the action space dictionary values from environment + env_action_space_dict = env.action_dict.values() + # Flags to check the conditions of the action space + contains_acl_actions = False + contains_node_actions = False + both_action_spaces = False + # Loop through each element of the list (which is every value from the dictionary) + for dict_item in env_action_space_dict: + # Node action detected + if len(dict_item) == 4: + contains_node_actions = True + # Link action detected + elif len(dict_item) == 6: + contains_acl_actions = True + # If both are there then the ANY action type is working + if contains_node_actions and contains_acl_actions: + both_action_spaces = True + # Check 
condition should be True + assert both_action_spaces -def test_agent_is_executing_actions_from_both_spaces(): +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "single_action_space_fixed_blue_actions_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "single_action_space_lay_down_config.yaml", - ) - # Run environment with specified fixed blue agent actions only - run_generic_set_actions(env) - # Retrieve hardware state of computer_1 node in laydown config - # Agent turned this off in Step 5 - computer_node_hardware_state = env.nodes["1"].hardware_state - # Retrieve the Access Control List object stored by the environment at the end of the episode - access_control_list = env.acl - # Use the Access Control List object acl object attribute to get dictionary - # Use dictionary.values() to get total list of all items in the dictionary - acl_rules_list = access_control_list.acl.values() - # Length of this list tells you how many items are in the dictionary - # This number is the frequency of Access Control Rules in the environment - # In the scenario, we specified that the agent should create only 1 acl rule - num_of_rules = len(acl_rules_list) - # Therefore these statements below MUST be true - assert computer_node_hardware_state == HardwareState.OFF - assert num_of_rules == 1 + with temp_primaite_session as session: + env = session.env + # Run environment with specified fixed blue agent actions only + run_generic_set_actions(env) + # Retrieve hardware state of computer_1 node in laydown config + # Agent turned this off in Step 
5 + computer_node_hardware_state = env.nodes["1"].hardware_state + # Retrieve the Access Control List object stored by the environment at the end of the episode + access_control_list = env.acl + # Use the Access Control List object acl object attribute to get dictionary + # Use dictionary.values() to get total list of all items in the dictionary + acl_rules_list = access_control_list.acl.values() + # Length of this list tells you how many items are in the dictionary + # This number is the frequency of Access Control Rules in the environment + # In the scenario, we specified that the agent should create only 1 acl rule + num_of_rules = len(acl_rules_list) + # Therefore these statements below MUST be true + assert computer_node_hardware_state == HardwareState.OFF + assert num_of_rules == 1 diff --git a/tests/test_training_config.py b/tests/test_training_config.py index 02e90d30..d7fe4e50 100644 --- a/tests/test_training_config.py +++ b/tests/test_training_config.py @@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT def test_legacy_lay_down_config_yaml_conversion(): """Tests the conversion of legacy lay down config files.""" - legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" - new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml" with open(legacy_path, "r") as file: legacy_dict = yaml.safe_load(file) @@ -24,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion(): def test_create_config_values_main_from_file(): """Tests creating an instance of TrainingConfig from file.""" - new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml" training_config.load(new_path) def test_create_config_values_main_from_legacy_file(): """Tests creating an instance of TrainingConfig from 
legacy file.""" - new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml" training_config.load(new_path, legacy_file=True)