diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 8bfdca02..691f71e9 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -59,4 +59,4 @@ steps: - script: | pytest tests/ - displayName: 'Run unmarked tests' + displayName: 'Run tests' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a08b17b8..26cd5697 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,6 +13,9 @@ repos: rev: 23.1.0 hooks: - id: black + args: [ "--line-length=79" ] + additional_dependencies: + - jupyter - repo: http://github.com/pycqa/isort rev: 5.12.0 hooks: @@ -22,4 +25,5 @@ repos: rev: 6.0.0 hooks: - id: flake8 - additional_dependencies: [ flake8-docstrings ] + additional_dependencies: + - flake8-docstrings diff --git a/docs/source/config.rst b/docs/source/config.rst index 52748eec..22fd0c01 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -22,7 +22,7 @@ The environment config file consists of the following attributes: * **agent_framework** [enum] - + This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following: * NONE - Where a user developed agent is to be used @@ -30,14 +30,14 @@ The environment config file consists of the following attributes: * RLLIB - Ray RLlib. * **agent_identifier** - + This identifies the agent to use for the session. Select from one of the following: * A2C - Advantage Actor Critic * PPO - Proximal Policy Optimization * HARDCODED - A custom built deterministic agent * RANDOM - A Stochastic random agent - + * **action_type** [enum] diff --git a/docs/source/primaite-dependencies.rst b/docs/source/primaite-dependencies.rst index bf6bd6e3..48f835fe 100644 --- a/docs/source/primaite-dependencies.rst +++ b/docs/source/primaite-dependencies.rst @@ -47,6 +47,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| astunparse | 1.6.3 | BSD License | https://github.com/simonpercivall/astunparse | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall | @@ -103,6 +105,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn | @@ -111,9 +115,13 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | furo | 2023.3.27 | MIT License | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| google-auth-oauthlib | 1.0.0 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib | +| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | grpcio | 1.51.3 | Apache Software License | https://grpc.io | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ @@ -121,6 +129,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| h5py | 3.9.0 | BSD License | https://www.h5py.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | idna | 3.4 | BSD License | https://github.com/kjd/idna | @@ -141,6 +151,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 | @@ -151,14 +163,14 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter-events | 0.6.3 | BSD License | http://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_client | 8.2.0 | BSD License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_core | 5.3.0 | BSD License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org | @@ -171,10 +183,14 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| keras | 2.12.0 | Apache Software License | https://keras.io/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | kiwisolver | 1.4.4 | BSD License | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py | @@ -183,19 +199,23 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mpmath | 1.3.0 | BSD License | http://mpmath.org/ | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic | +| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | nbclient | 0.8.0 | BSD License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ @@ -217,6 +237,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging | @@ -231,11 +253,17 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| plotly | 5.15.0 | MIT License | https://plotly.com/python/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| primaite | 1.2.1 | GFX | UNKNOWN | +| primaite | 2.0.0rc1 | GFX | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | GFX | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | GFX | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ @@ -295,6 +323,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | rsa | 4.9 | Apache Software License | https://stuvel.eu/rsa | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | scikit-image | 0.20.0 | BSD License | https://scikit-image.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | scipy | 1.10.1 | BSD License | https://scipy.org/ | @@ -335,14 +365,26 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| tensorboard | 2.12.3 | Apache Software License | https://github.com/tensorflow/tensorboard | +| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| tensorboard-data-server | 0.7.0 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server | +| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com | @@ -377,6 +419,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket | diff --git a/pyproject.toml b/pyproject.toml index aa9f5fdc..09b60777 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "networkx==3.1", "numpy==1.23.5", "platformdirs==3.5.1", + "plotly==5.15.0", + "polars==0.18.4", "PyYAML==6.0", "ray[rllib]==2.2.0", "stable-baselines3==1.6.2", @@ -69,3 +71,12 @@ tensorflow = [ [project.scripts] primaite = "primaite.cli:app" + +[tool.isort] +profile = "black" +line_length = 79 +force_sort_within_sections = "False" +order_by_type = "False" + +[tool.black] +line-length = 79 diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 3068ee27..4111d137 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0rc1 \ No newline at end of file +2.0.0rc1 diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 64857c80..e753b4ef 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -3,12 +3,10 @@ import logging import logging.config import sys from bisect import bisect -from logging import Formatter, LogRecord, StreamHandler -from logging import Logger +from logging import Formatter, Logger, LogRecord, StreamHandler from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Dict -from typing import Final +from typing import Dict, Final import pkg_resources import yaml @@ -21,7 +19,6 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not config_path.exists(): - config_path = Path( pkg_resources.resource_filename( "primaite", "setup/_package_data/primaite_config.yaml" @@ -37,7 +34,9 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] + primaite_config["log_level"] = log_level_map[ + primaite_config["logging"]["log_level"] + ] return primaite_config @@ -111,9 +110,13 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][ + "WARNING" + ], logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], - logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"] + logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][ + "CRITICAL" + ], } ) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 284ed764..a147d963 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,7 +10,9 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules + self.acl: Dict[ + str, AccessControlList + ] = {} # A dictionary of ACL Rules def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): """ @@ -37,13 +39,17 @@ class AccessControlList: _rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY" ) - or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY") + or ( + _rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY" + ) ): return True else: return False - def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): + def is_blocked( + self, _source_ip_address, _dest_ip_address, _protocol, _port + ): """ Checks for rules that block a protocol / port. @@ -87,7 +93,9 @@ class AccessControlList: _protocol: the protocol _port: the port """ - new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + new_rule = ACLRule( + _permission, _source_ip, _dest_ip, _protocol, str(_port) + ) hash_value = hash(new_rule) self.acl[hash_value] = new_rule @@ -102,7 +110,9 @@ class AccessControlList: _protocol: the protocol _port: the port """ - rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + rule = ACLRule( + _permission, _source_ip, _dest_ip, _protocol, str(_port) + ) hash_value = hash(rule) # There will not always be something 'popable' since the agent will be trying random things try: @@ -114,7 +124,9 @@ class AccessControlList: """Removes all rules.""" self.acl.clear() - def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def get_dictionary_hash( + self, _permission, _source_ip, _dest_ip, _protocol, _port + ): """ Produces a hash value for a rule. @@ -128,6 +140,8 @@ class AccessControlList: Returns: Hash value based on rule parameters. """ - rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + rule = ACLRule( + _permission, _source_ip, _dest_ip, _protocol, str(_port) + ) hash_value = hash(rule) return hash_value diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index ef631a70..05daecc4 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -30,7 +30,13 @@ class ACLRule: Returns hash of core parameters. """ return hash( - (self.permission, self.source_ip, self.dest_ip, self.protocol, self.port) + ( + self.permission, + self.source_ip, + self.dest_ip, + self.protocol, + self.port, + ) ) def get_permission(self): diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index f545a3cb..c76583c0 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,27 +1,31 @@ from __future__ import annotations + import json import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Optional, Final, Dict, Union +from typing import Dict, Final, Optional, Union from uuid import uuid4 import yaml import primaite from primaite import getLogger, SESSIONS_DIR -from primaite.config import lay_down_config -from primaite.config import training_config +from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig +from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) -def _get_session_path(session_timestamp: datetime) -> Path: +def get_session_path(session_timestamp: datetime) -> Path: """ - Get a temp directory session path the test session will output to. + Get the directory path the session will output to. + + This is set in the format of: + ~/primaite/sessions//_. :param session_timestamp: This is the datetime that the session started. :return: The session directory path. @@ -35,13 +39,15 @@ def _get_session_path(session_timestamp: datetime) -> Path: class AgentSessionABC(ABC): + """ + An ABC that manages training and/or evaluation of agents in PrimAITE. + + This class cannot be directly instantiated and must be inherited from + with all implemented abstract methods implemented. + """ @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): + def __init__(self, training_config_path, lay_down_config_path): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path @@ -66,9 +72,8 @@ class AgentSessionABC(ABC): self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() "The session timestamp" - self.session_path = _get_session_path(self.session_timestamp) + self.session_path = get_session_path(self.session_timestamp) "The Session path" - self.checkpoints_path.mkdir(parents=True, exist_ok=True) @property def timestamp_str(self) -> str: @@ -78,17 +83,23 @@ class AgentSessionABC(ABC): @property def learning_path(self) -> Path: """The learning outputs path.""" - return self.session_path / "learning" + path = self.session_path / "learning" + path.mkdir(exist_ok=True, parents=True) + return path @property def evaluation_path(self) -> Path: """The evaluation outputs path.""" - return self.session_path / "evaluation" + path = self.session_path / "evaluation" + path.mkdir(exist_ok=True, parents=True) + return path @property def checkpoints_path(self) -> Path: """The Session checkpoints path.""" - return self.learning_path / "checkpoints" + path = self.learning_path / "checkpoints" + path.mkdir(exist_ok=True, parents=True) + return path @property def uuid(self): @@ -118,14 +129,8 @@ class AgentSessionABC(ABC): "uuid": self.uuid, "start_datetime": self.session_timestamp.isoformat(), "end_datetime": None, - "learning": { - "total_episodes": None, - "total_time_steps": None - }, - "evaluation": { - "total_episodes": None, - "total_time_steps": None - }, + "learning": {"total_episodes": None, "total_time_steps": None}, + "evaluation": {"total_episodes": None, "total_time_steps": None}, "env": { "training_config": self._training_config.to_dict( json_serializable=True @@ -156,11 +161,19 @@ class AgentSessionABC(ABC): metadata_dict["end_datetime"] = datetime.now().isoformat() if not self.is_eval: - metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa - metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa + metadata_dict["learning"][ + "total_episodes" + ] = self._env.episode_count # noqa + metadata_dict["learning"][ + "total_time_steps" + ] = self._env.total_step_count # noqa else: - metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa - metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa + metadata_dict["evaluation"][ + "total_episodes" + ] = self._env.episode_count # noqa + metadata_dict["evaluation"][ + "total_time_steps" + ] = self._env.total_step_count # noqa filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -187,26 +200,47 @@ class AgentSessionABC(ABC): @abstractmethod def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ if self._can_learn: _LOGGER.info("Finished learning") _LOGGER.debug("Writing transactions") self._update_session_metadata_file() self._can_evaluate = True self.is_eval = False + self._plot_av_reward_per_episode(learning_session=True) @abstractmethod def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ + self._env.set_as_eval() # noqa self.is_eval = True + self._plot_av_reward_per_episode(learning_session=False) _LOGGER.info("Finished evaluation") @abstractmethod @@ -216,6 +250,7 @@ class AgentSessionABC(ABC): @classmethod @abstractmethod def load(cls, path: Union[str, Path]) -> AgentSessionABC: + """Load an agent from file.""" if not isinstance(path, Path): path = Path(path) @@ -246,21 +281,56 @@ class AgentSessionABC(ABC): else: # Session path does not exist - msg = f"Failed to load PrimAITE Session, path does not exist: {path}" + msg = ( + f"Failed to load PrimAITE Session, path does not exist: {path}" + ) _LOGGER.error(msg) raise FileNotFoundError(msg) pass @abstractmethod def save(self): + """Save the agent.""" self._agent.save(self.session_path) @abstractmethod def export(self): + """Export the agent to transportable file format.""" pass + def close(self): + """Closes the agent.""" + self._env.episode_av_reward_writer.close() # noqa + self._env.transaction_writer.close() # noqa + + def _plot_av_reward_per_episode(self, learning_session: bool = True): + # self.close() + title = f"PrimAITE Session {self.timestamp_str} " + subtitle = str(self._training_config) + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + image_file = f"average_reward_per_episode_{self.timestamp_str}.png" + if learning_session: + title += "(Learning)" + path = self.learning_path / csv_file + image_path = self.learning_path / image_file + else: + title += "(Evaluation)" + path = self.evaluation_path / csv_file + image_path = self.evaluation_path / image_file + + fig = plot_av_reward_per_episode(path, title, subtitle) + fig.write_image(image_path) + _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") + class HardCodedAgentSessionABC(AgentSessionABC): + """ + An Agent Session ABC for evaluation deterministic agents. + + This class cannot be directly instantiated and must be inherited from + with all implemented abstract methods implemented. + """ + def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) self._setup() @@ -270,13 +340,12 @@ class HardCodedAgentSessionABC(AgentSessionABC): training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, session_path=self.session_path, - timestamp_str=self.timestamp_str + timestamp_str=self.timestamp_str, ) super()._setup() self._can_learn = False self._can_evaluate = True - def _save_checkpoint(self): pass @@ -284,11 +353,20 @@ class HardCodedAgentSessionABC(AgentSessionABC): pass def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ _LOGGER.warning("Deterministic agents cannot learn") @abstractmethod @@ -296,20 +374,31 @@ class HardCodedAgentSessionABC(AgentSessionABC): pass def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ + self._env.set_as_eval() # noqa + self.is_eval = True + if not time_steps: time_steps = self._training_config.num_steps if not episodes: episodes = self._training_config.num_episodes - + obs = self._env.reset() for episode in range(episodes): # Reset env and collect initial observation - obs = self._env.reset() for step in range(time_steps): # Calculate action action = self._calculate_action(obs) @@ -322,15 +411,18 @@ class HardCodedAgentSessionABC(AgentSessionABC): # Introduce a delay between steps time.sleep(self._training_config.time_delay / 1000) + obs = self._env.reset() self._env.close() - super().evaluate() @classmethod def load(cls): + """Load an agent from file.""" _LOGGER.warning("Deterministic agents cannot be loaded") def save(self): + """Save the agent.""" _LOGGER.warning("Deterministic agents cannot be saved") def export(self): + """Export the agent to transportable file format.""" _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 4ad08f6e..f70320f1 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -11,9 +11,13 @@ from primaite.common.enums import HardCodedAgentView class HardCodedACLAgent(HardCodedAgentSessionABC): + """An Agent Session class that implements a deterministic ACL agent.""" def _calculate_action(self, obs): - if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: + if ( + self._training_config.hard_coded_agent_view + == HardCodedAgentView.BASIC + ): # Basic view action using only the current observation return self._calculate_action_basic_view(obs) else: @@ -22,6 +26,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return self._calculate_action_full_view(obs) def get_blocked_green_iers(self, green_iers, acl, nodes): + """ + Get blocked green IERs. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ blocked_green_iers = {} for green_ier_id, green_ier in green_iers.items(): @@ -33,8 +43,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): port = green_ier.get_port() # Can be blocked by an ACL or by default (no allow rule exists) - if acl.is_blocked(source_node_address, dest_node_address, protocol, - port): + if acl.is_blocked( + source_node_address, dest_node_address, protocol, port + ): blocked_green_iers[green_ier_id] = green_ier return blocked_green_iers @@ -42,8 +53,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_matching_acl_rules_for_ier(self, ier, acl, nodes): """ Get matching ACL rules for an IER. - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ source_node_id = ier.get_source_node_id() source_node_address = nodes[source_node_id].ip_address dest_node_id = ier.get_dest_node_id() @@ -51,17 +64,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - matching_rules = acl.get_relevant_rules(source_node_address, - dest_node_address, protocol, - port) + matching_rules = acl.get_relevant_rules( + source_node_address, dest_node_address, protocol, port + ) return matching_rules def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): """ Get blocking ACL rules for an IER. - Warning: Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked) - """ + .. warning:: + Can return empty dict but IER can still be blocked by default + (No ALLOW rule, therefore blocked). + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) blocked_rules = {} @@ -74,8 +92,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules_for_ier(self, ier, acl, nodes): """ Get all allowing ACL rules for an IER. - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) allowed_rules = {} @@ -85,9 +105,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol, - port, acl, - nodes, services_list): + def get_matching_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get matching ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if source_node_id != "ANY": source_node_address = nodes[str(source_node_id)].ip_address else: @@ -100,21 +133,39 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if protocol != "ANY": protocol = services_list[ - protocol - 1] # -1 as dont have to account for ANY in list of services + protocol - 1 + ] # -1 as dont have to account for ANY in list of services - matching_rules = acl.get_relevant_rules(source_node_address, - dest_node_address, protocol, - port) + matching_rules = acl.get_relevant_rules( + source_node_address, dest_node_address, protocol, port + ) return matching_rules - def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol, - port, acl, - nodes, services_list): - matching_rules = self.get_matching_acl_rules(source_node_id, - dest_node_id, - protocol, port, acl, - nodes, - services_list) + def get_allow_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get the ALLOW ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + matching_rules = self.get_matching_acl_rules( + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ) allowed_rules = {} for rule_key, rule_value in matching_rules.items(): @@ -123,14 +174,31 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port, - acl, - nodes, services_list): - matching_rules = self.get_matching_acl_rules(source_node_id, - dest_node_id, - protocol, port, acl, - nodes, - services_list) + def get_deny_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get the DENY ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + matching_rules = self.get_matching_acl_rules( + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ) allowed_rules = {} for rule_key, rule_value in matching_rules.items(): @@ -141,7 +209,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def _calculate_action_full_view(self, obs): """ - Given an observation and the environment calculate a good acl-based action for the blue agent to take + Calculate a good acl-based action for the blue agent to take. Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: @@ -167,8 +235,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing an overwhelmed state, so we don't do this. + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ - #obs = convert_to_old_obs(obs) + # obs = convert_to_old_obs(obs) r_obs = transform_change_obs_readable(obs) _, _, _, *s = r_obs @@ -184,7 +254,6 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): for service_num, service_states in enumerate(s): for x, service_state in enumerate(service_states): if service_state == "COMPROMISED": - action_source_id = x + 1 # +1 as 0 is any action_destination_id = "ANY" action_protocol = service_num + 1 # +1 as 0 is any @@ -215,19 +284,23 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action_permission = "ALLOW" action_source_ip = rule.get_source_ip() action_source_id = int( - get_node_of_ip(action_source_ip, self._env.nodes)) + get_node_of_ip(action_source_ip, self._env.nodes) + ) action_destination_ip = rule.get_dest_ip() action_destination_id = int( - get_node_of_ip(action_destination_ip, - self._env.nodes)) + get_node_of_ip( + action_destination_ip, self._env.nodes + ) + ) action_protocol_name = rule.get_protocol() action_protocol = ( - self._env.services_list.index( - action_protocol_name) + 1 + self._env.services_list.index(action_protocol_name) + + 1 ) # convert name e.g. 'TCP' to index action_port_name = rule.get_port() - action_port = self._env.ports_list.index( - action_port_name) + 1 # convert port name e.g. '80' to index + action_port = ( + self._env.ports_list.index(action_port_name) + 1 + ) # convert port name e.g. '80' to index found_action = True break @@ -258,21 +331,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if not found_action: # Which Green IERS are blocked blocked_green_iers = self.get_blocked_green_iers( - self._env.green_iers, self._env.acl, - self._env.nodes) + self._env.green_iers, self._env.acl, self._env.nodes + ) for ier_key, ier in blocked_green_iers.items(): - # Which ALLOW rules are allowing this IER (none) - allowing_rules = self.get_allow_acl_rules_for_ier(ier, - self._env.acl, - self._env.nodes) + allowing_rules = self.get_allow_acl_rules_for_ier( + ier, self._env.acl, self._env.nodes + ) # If there are no blocking rules, it may be being blocked by default # If there is already an allow rule node_id_to_check = int(ier.get_source_node_id()) service_name_to_check = ier.get_protocol() service_id_to_check = self._env.services_list.index( - service_name_to_check) + service_name_to_check + ) # Service state of the the source node in the ier service_state = s[service_id_to_check][node_id_to_check - 1] @@ -283,11 +356,13 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action_source_id = int(ier.get_source_node_id()) action_destination_id = int(ier.get_dest_node_id()) action_protocol_name = ier.get_protocol() - action_protocol = self._env.services_list.index( - action_protocol_name) + 1 # convert name e.g. 'TCP' to index + action_protocol = ( + self._env.services_list.index(action_protocol_name) + 1 + ) # convert name e.g. 'TCP' to index action_port_name = ier.get_port() - action_port = self._env.ports_list.index( - action_port_name) + 1 # convert port name e.g. '80' to index + action_port = ( + self._env.ports_list.index(action_port_name) + 1 + ) # convert port name e.g. '80' to index found_action = True break @@ -311,19 +386,25 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return action def _calculate_action_basic_view(self, obs): - """Given an observation calculate a good acl-based action for the blue agent to take + """Calculate a good acl-based action for the blue agent to take. - Uses ONLY information from the current observation with NO knowledge of previous actions taken and - NO reward feedback. + Uses ONLY information from the current observation with NO knowledge + of previous actions taken and NO reward feedback. - We rely on randomness to select the precise action, as we want to block all traffic originating from - a compromised node, without being able to tell: + We rely on randomness to select the precise action, as we want to + block all traffic originating from a compromised node, without being + able to tell: 1. Which ACL rules already exist - 1. Which actions the agent has already tried. + 2. Which actions the agent has already tried. - There is a high probability that the correct rule will not be deleted before the state becomes overwhelmed. + There is a high probability that the correct rule will not be deleted + before the state becomes overwhelmed. - Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted. + Currently, a deny rule does not overwrite an allow rule. The allow + rules must be deleted. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) @@ -333,27 +414,35 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): s = [*s] number_of_nodes = len( - [i for i in o if i != "NONE"]) # number of nodes (not links) + [i for i in o if i != "NONE"] + ) # number of nodes (not links) for service_num, service_states in enumerate(s): - comprimised_states = [n for n, i in enumerate(service_states) if - i == "COMPROMISED"] + comprimised_states = [ + n for n, i in enumerate(service_states) if i == "COMPROMISED" + ] if len(comprimised_states) == 0: # No states are COMPROMISED, try the next service continue - compromised_node = np.random.choice( - comprimised_states) + 1 # +1 as 0 would be any + compromised_node = ( + np.random.choice(comprimised_states) + 1 + ) # +1 as 0 would be any action_decision = "DELETE" action_permission = "ALLOW" action_source_ip = compromised_node # Randomly select a destination ID to block action_destination_ip = np.random.choice( - list(range(1, number_of_nodes + 1)) + ["ANY"]) - action_destination_ip = int( - action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip + list(range(1, number_of_nodes + 1)) + ["ANY"] + ) + action_destination_ip = ( + int(action_destination_ip) + if action_destination_ip != "ANY" + else action_destination_ip + ) action_protocol = service_num + 1 # +1 as 0 is any # Randomly select a port to block - # Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port + # Bad assumption that number of protocols equals number of ports + # AND no rules exist with an ANY port action_port = np.random.choice(list(range(1, len(s) + 1))) action = [ diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 6db43da6..e258edb0 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,16 +1,21 @@ from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, - transform_change_obs_readable, -) -from primaite.agents.utils import ( transform_action_node_enum, + transform_change_obs_readable, ) class HardCodedNodeAgent(HardCodedAgentSessionABC): + """An Agent Session class that implements a deterministic Node agent.""" + def _calculate_action(self, obs): - """Given an observation calculate a good node-based action for the blue agent to take""" + """ + Calculate a good node-based action for the blue agent to take. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) _, o, os, *s = r_obs @@ -18,7 +23,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): if len(r_obs) == 4: # only 1 service s = [*s] - # Check in order of most important states (order doesn't currently matter, but it probably should) + # Check in order of most important states (order doesn't currently + # matter, but it probably should) # First see if any OS states are compromised for x, os_state in enumerate(os): if os_state == "COMPROMISED": @@ -26,8 +32,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): action_node_property = "OS" property_action = "PATCHING" action_service_index = 0 # does nothing isn't relevant for os - action = [action_node_id, action_node_property, - property_action, action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -44,8 +54,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): property_action = "PATCHING" action_service_index = service_num - action = [action_node_id, action_node_property, - property_action, action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -63,8 +77,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): property_action = "PATCHING" action_service_index = service_num - action = [action_node_id, action_node_property, - property_action, action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -75,10 +93,18 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): if os_state == "OFF": action_node_id = x + 1 action_node_property = "OPERATING" - property_action = "ON" # Why reset it when we can just turn it on - action_service_index = 0 # does nothing isn't relevant for operating state - action = [action_node_id, action_node_property, - property_action, action_service_index] + property_action = ( + "ON" # Why reset it when we can just turn it on + ) + action_service_index = ( + 0 # does nothing isn't relevant for operating state + ) + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action, action_dict) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -89,8 +115,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): action_node_property = "NONE" property_action = "ON" action_service_index = 0 - action = [action_node_id, action_node_property, property_action, - action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 8a6428bb..35ae1b53 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,28 +1,35 @@ +from __future__ import annotations + import json from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Optional, Union +import tensorflow as tf from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env -import tensorflow as tf + from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier, \ - DeepLearningFramework +from primaite.common.enums import ( + AgentFramework, + AgentIdentifier, + DeepLearningFramework, +) from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) + def _env_creator(env_config): return Primaite( training_config_path=env_config["training_config_path"], lay_down_config_path=env_config["lay_down_config_path"], session_path=env_config["session_path"], - timestamp_str=env_config["timestamp_str"] + timestamp_str=env_config["timestamp_str"], ) @@ -37,16 +44,15 @@ def _custom_log_creator(session_path: Path): class RLlibAgent(AgentSessionABC): + """An AgentSession class that implements a Ray RLlib agent.""" - def __init__( - self, - training_config_path, - lay_down_config_path - ): + def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: - msg = (f"Expected RLLIB agent_framework, " - f"got {self._training_config.agent_framework}") + msg = ( + f"Expected RLLIB agent_framework, " + f"got {self._training_config.agent_framework}" + ) _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -54,8 +60,10 @@ class RLlibAgent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_config_class = A2CConfig else: - msg = ("Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier.value}") + msg = ( + "Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier.value}" + ) _LOGGER.error(msg) raise ValueError(msg) self._agent_config: PPOConfig @@ -86,8 +94,12 @@ class RLlibAgent(AgentSessionABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._current_result["episodes_total"] - metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] + metadata_dict["total_episodes"] = self._current_result[ + "episodes_total" + ] + metadata_dict["total_time_steps"] = self._current_result[ + "timesteps_total" + ] filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -106,43 +118,48 @@ class RLlibAgent(AgentSessionABC): training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, session_path=self.session_path, - timestamp_str=self.timestamp_str - ) + timestamp_str=self.timestamp_str, + ), ) self._agent_config.training( train_batch_size=self._training_config.num_steps ) - self._agent_config.framework( - framework="tf" - ) + self._agent_config.framework(framework="tf") self._agent_config.rollouts( num_rollout_workers=1, num_envs_per_worker=1, - horizon=self._training_config.num_steps + horizon=self._training_config.num_steps, ) self._agent: Algorithm = self._agent_config.build( logger_creator=_custom_log_creator(self.session_path) ) - def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] if checkpoint_n > 0 and episode_count > 0: - if ( - (episode_count % checkpoint_n == 0) - or (episode_count == self._training_config.num_episodes) + if (episode_count % checkpoint_n == 0) or ( + episode_count == self._training_config.num_episodes ): - self._agent.save(self.checkpoints_path) + self._agent.save(str(self.checkpoints_path)) def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ # Temporarily override train_batch_size and horizon if time_steps: self._agent_config.train_batch_size = time_steps @@ -150,37 +167,53 @@ class RLlibAgent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - _LOGGER.info(f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps...") + _LOGGER.info( + f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps..." + ) for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() - if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH: + if ( + self._training_config.deep_learning_framework + != DeepLearningFramework.TORCH + ): policy = self._agent.get_policy() tf.compat.v1.summary.FileWriter( - self.session_path / "ray_results", - policy.get_session().graph + self.session_path / "ray_results", policy.get_session().graph ) super().learn() self._agent.stop() def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ raise NotImplementedError def _get_latest_checkpoint(self): raise NotImplementedError @classmethod - def load(cls): + def load(cls, path: Union[str, Path]) -> RLlibAgent: + """Load an agent from file.""" raise NotImplementedError def save(self): + """Save the agent.""" raise NotImplementedError def export(self): + """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 328e6286..8d5dd633 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,27 +1,30 @@ -from typing import Optional +from __future__ import annotations + +from pathlib import Path +from typing import Optional, Union import numpy as np -from stable_baselines3 import PPO, A2C +from stable_baselines3 import A2C, PPO from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentIdentifier, AgentFramework +from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) class SB3Agent(AgentSessionABC): - def __init__( - self, - training_config_path, - lay_down_config_path - ): + """An AgentSession class that implements a Stable Baselines3 agent.""" + + def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.SB3: - msg = (f"Expected SB3 agent_framework, " - f"got {self._training_config.agent_framework}") + msg = ( + f"Expected SB3 agent_framework, " + f"got {self._training_config.agent_framework}" + ) _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -29,8 +32,10 @@ class SB3Agent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_class = A2C else: - msg = ("Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier.value}") + msg = ( + "Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier}" + ) _LOGGER.error(msg) raise ValueError(msg) @@ -52,25 +57,26 @@ class SB3Agent(AgentSessionABC): training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, session_path=self.session_path, - timestamp_str=self.timestamp_str + timestamp_str=self.timestamp_str, ) self._agent = self._agent_class( PPOMlp, self._env, verbose=self.output_verbose_level, n_steps=self._training_config.num_steps, - tensorboard_log=self._tensorboard_log_path + tensorboard_log=self._tensorboard_log_path, ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count if checkpoint_n > 0 and episode_count > 0: - if ( - (episode_count % checkpoint_n == 0) - or (episode_count == self._training_config.num_episodes) + if (episode_count % checkpoint_n == 0) or ( + episode_count == self._training_config.num_episodes ): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + checkpoint_path = ( + self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + ) self._agent.save(checkpoint_path) _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") @@ -78,33 +84,54 @@ class SB3Agent(AgentSessionABC): pass def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ if not time_steps: time_steps = self._training_config.num_steps if not episodes: episodes = self._training_config.num_episodes self.is_eval = False - _LOGGER.info(f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps...") + _LOGGER.info( + f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps..." + ) for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() - self._env.close() + self.close() super().learn() def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - deterministic: bool = True, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + deterministic: bool = True, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param deterministic: Whether the evaluation is deterministic. + :param kwargs: Any agent-specific key-word args to be passed. + """ if not time_steps: time_steps = self._training_config.num_steps @@ -116,27 +143,31 @@ class SB3Agent(AgentSessionABC): deterministic_str = "deterministic" else: deterministic_str = "non-deterministic" - _LOGGER.info(f"Beginning {deterministic_str} evaluation for " - f"{episodes} episodes @ {time_steps} time steps...") + _LOGGER.info( + f"Beginning {deterministic_str} evaluation for " + f"{episodes} episodes @ {time_steps} time steps..." + ) for episode in range(episodes): obs = self._env.reset() for step in range(time_steps): action, _states = self._agent.predict( - obs, - deterministic=deterministic + obs, deterministic=deterministic ) if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) - _LOGGER.info(f"Finished evaluation") + super().evaluate() @classmethod - def load(self): + def load(cls, path: Union[str, Path]) -> SB3Agent: + """Load an agent from file.""" raise NotImplementedError def save(self): + """Save the agent.""" raise NotImplementedError def export(self): + """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index a4eadc3b..c3e67fdf 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -4,9 +4,9 @@ from primaite.common.enums import ( HardwareState, LinkStatus, NodeHardwareAction, + NodePOLType, NodeSoftwareAction, SoftwareState, - NodePOLType ) @@ -16,14 +16,17 @@ def transform_action_node_readable(action): example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_node_property = NodePOLType(action[1]).name if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[ - 2 - ] <= 1: + elif ( + action_node_property == "OS" or action_node_property == "SERVICE" + ) and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -38,6 +41,9 @@ def transform_action_acl_readable(action): example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -62,6 +68,9 @@ def is_valid_node_action(action): Does NOT consider: - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_node_readable(action) @@ -77,7 +86,10 @@ def is_valid_node_action(action): if node_property == "OPERATING" and node_action == "PATCHING": # Operating State cannot PATCH return False - if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + if node_property != "OPERATING" and node_action not in [ + "NONE", + "PATCHING", + ]: # Software States can only do Nothing or Patch return False return True @@ -92,6 +104,9 @@ def is_valid_acl_action(action): Does NOT consider: - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_acl_readable(action) @@ -118,7 +133,12 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """Harsher version of valid acl actions, does not allow action.""" + """ + Harsher version of valid acl actions, does not allow action. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if is_valid_acl_action(action) is False: return False @@ -136,13 +156,15 @@ def is_valid_acl_action_extra(action): return True - def transform_change_obs_readable(obs): - """Transform list of transactions to readable list of each observation property + """ + Transform list of transactions to readable list of each observation property. example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ ids = [i for i in obs[:, 0]] operating_states = [HardwareState(i).name for i in obs[:, 1]] @@ -151,7 +173,9 @@ def transform_change_obs_readable(obs): for service in range(3, obs.shape[1]): # Links bit/s don't have a service state - service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] + service_states = [ + SoftwareState(i).name if i <= 4 else i for i in obs[:, service] + ] new_obs.append(service_states) return new_obs @@ -159,10 +183,13 @@ def transform_change_obs_readable(obs): def transform_obs_readable(obs): """ - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - """ + Transform observation to readable format. + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ changed_obs = transform_change_obs_readable(obs) new_obs = list(zip(*changed_obs)) # Convert list of tuples to list of lists @@ -172,7 +199,12 @@ def transform_obs_readable(obs): def convert_to_new_obs(obs, num_nodes=10): - """Convert original gym Box observation space to new multiDiscrete observation space""" + """ + Convert original gym Box observation space to new multiDiscrete observation space. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ # Remove ID columns, remove links and flatten to MultiDiscrete observation space new_obs = obs[:num_nodes, 1:].flatten() return new_obs @@ -180,7 +212,9 @@ def convert_to_new_obs(obs, num_nodes=10): def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): """ - Convert to old observation, links filled with 0's as no information is included in new observation space + Convert to old observation. + + Links filled with 0's as no information is included in new observation space. example: obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) @@ -190,13 +224,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): [ 3, 1, 1, 1], ... [20, 0, 0, 0]]) + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ - # Convert back to more readable, original format reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) # Add empty links back and add node ID back - s = np.zeros([reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], dtype=np.int64) + s = np.zeros( + [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back s[:num_nodes, 1:] = reshaped_nodes # put values back in new_obs = s @@ -209,14 +247,19 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """Return string describing change between two observations +def describe_obs_change( + obs1, obs2, num_nodes=10, num_links=10, num_services=1 +): + """ + Return string describing change between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) output = 'ID 1: SERVICE 2 set to GOOD' + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) @@ -236,20 +279,27 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): def _describe_obs_change_helper(obs_change, is_link): - """ " - Helper funcion to describe what has changed + """ + Helper funcion to describe what has changed. example: [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ # Indexes where a change has occured, not including 0th index - index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] + index_changed = [ + i for i in range(1, len(obs_change)) if obs_change[i] != -1 + ] # Node pol types, Indexes >= 3 are service nodes NodePOLTypes = [ - NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed + NodePOLType(i).name + if i < 3 + else NodePOLType(3).name + " " + str(i - 3) + for i in index_changed ] # Account for hardware states, software sattes and links states = [ @@ -263,8 +313,8 @@ def _describe_obs_change_helper(obs_change, is_link): if not is_link: desc = f"ID {obs_change[0]}:" - for NodePOLType, state in list(zip(NodePOLTypes, states)): - desc = desc + " " + NodePOLType + " changed to " + state + "." + for node_pol_type, state in list(zip(NodePOLTypes, states)): + desc = desc + " " + node_pol_type + " changed to " + state + "." else: desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." @@ -273,12 +323,14 @@ def _describe_obs_change_helper(obs_change, is_link): def transform_action_node_enum(action): """ - Convert a node action from readable string format, to enumerated format + Convert a node action from readable string format, to enumerated format. example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_node_id = action[0] action_node_property = NodePOLType[action[1]].value @@ -291,24 +343,33 @@ def transform_action_node_enum(action): action_service_index = action[3] - new_action = [action_node_id, action_node_property, property_action, action_service_index] + new_action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] return new_action def transform_action_node_readable(action): """ - Convert a node action from enumerated format to readable format + Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_node_property = NodePOLType(action[1]).name if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: + elif ( + action_node_property == "OS" or action_node_property == "SERVICE" + ) and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -319,9 +380,11 @@ def transform_action_node_readable(action): def node_action_description(action): """ - Generate string describing a node-based action - """ + Generate string describing a node-based action. + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if isinstance(action[1], (int, np.int64)): # transform action to readable format action = transform_action_node_readable(action) @@ -334,7 +397,9 @@ def node_action_description(action): if property_action == "NONE": return "" if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" + description = ( + f"NODE {node_id}, {node_property}, SET TO {property_action}" + ) elif node_property == "SERVICE": description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" else: @@ -343,34 +408,13 @@ def node_action_description(action): return description -def transform_action_acl_readable(action): - """ - Transform an ACL action to a more readable format - - example: - [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - """ - - action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} - action_permissions = {0: "DENY", 1: "ALLOW"} - - action_decision = action_decisions[action[0]] - action_permission = action_permissions[action[1]] - - # For IPs, Ports and Protocols, 0 means any, otherwise its just an index - new_action = [action_decision, action_permission] + list(action[2:6]) - for n, val in enumerate(list(action[2:6])): - if val == 0: - new_action[n + 2] = "ANY" - - return new_action - - def transform_action_acl_enum(action): """ - Convert a acl action from readable string format, to enumerated format - """ + Convert acl action from readable str format, to enumerated format. + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} action_permissions = {"DENY": 0, "ALLOW": 1} @@ -388,8 +432,12 @@ def transform_action_acl_enum(action): def acl_action_description(action): - """generate string describing a acl-based action""" + """ + Generate string describing an acl-based action. + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if isinstance(action[0], (int, np.int64)): # transform action to readable format action = transform_action_acl_readable(action) @@ -406,11 +454,13 @@ def acl_action_description(action): def get_node_of_ip(ip, node_dict): """ - Get the node ID of an IP address + Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ for node_key, node_value in node_dict.items(): node_ip = node_value.ip_address if node_ip == ip: @@ -418,13 +468,16 @@ def get_node_of_ip(ip, node_dict): def is_valid_node_action(action): - """Is the node action an actual valid action + """Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect Does NOT consider: - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_node_readable(action) @@ -438,7 +491,10 @@ def is_valid_node_action(action): if node_property == "OPERATING" and node_action == "PATCHING": # Operating State cannot PATCH return False - if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + if node_property != "OPERATING" and node_action not in [ + "NONE", + "PATCHING", + ]: # Software States can only do Nothing or Patch return False return True @@ -446,13 +502,16 @@ def is_valid_node_action(action): def is_valid_acl_action(action): """ - Is the ACL action an actual valid action + Is the ACL action an actual valid action. Only uses information about the action to determine if the action has an effect Does NOT consider: - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_acl_readable(action) @@ -463,7 +522,11 @@ def is_valid_acl_action(action): if action_decision == "NONE": return False - if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": + if ( + action_source_id == action_destination_id + and action_source_id != "ANY" + and action_destination_id != "ANY" + ): # ACL rule towards itself return False if action_permission == "DENY": @@ -475,7 +538,12 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """Harsher version of valid acl actions, does not allow action""" + """ + Harsher version of valid acl actions, does not allow action. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if is_valid_acl_action(action) is False: return False @@ -494,33 +562,17 @@ def is_valid_acl_action_extra(action): def get_new_action(old_action, action_dict): - """Get new action (e.g. 32) from old action e.g. [1,1,1,0] - - old_action can be either node or acl action type """ + Get new action (e.g. 32) from old action e.g. [1,1,1,0]. + Old_action can be either node or acl action type + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ for key, val in action_dict.items(): if list(val) == list(old_action): return key # Not all possible actions are included in dict, only valid action are # if action is not in the dict, its an invalid action so return 0 return 0 - - -def get_action_description(action, action_dict): - """ - Get a string describing/explaining what an action is doing in words - """ - - action_array = action_dict[action] - if len(action_array) == 4: - # node actions have length 4 - action_description = node_action_description(action_array) - elif len(action_array) == 6: - # acl actions have length 6 - action_description = acl_action_description(action_array) - else: - # Should never happen - action_description = "Unrecognised action" - - return action_description diff --git a/src/primaite/cli.py b/src/primaite/cli.py index aa88a391..10e23bfc 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -13,6 +13,8 @@ import yaml from platformdirs import PlatformDirs from typing_extensions import Annotated +from primaite.data_viz import PlotlyTemplate + app = typer.Typer() @@ -54,7 +56,9 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa +_LogLevel = Enum( + "LogLevel", {k: k for k in logging._levelToName.values()} +) # noqa @app.command() @@ -76,11 +80,12 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None): primaite_config = yaml.safe_load(file) if level: - primaite_config["log_level"] = level.value + primaite_config["logging"]["log_level"] = level.value with open(user_config_path, "w") as file: yaml.dump(primaite_config, file) + print(f"PrimAITE Log Level: {level}") else: - level = primaite_config["log_level"] + level = primaite_config["logging"]["log_level"] print(f"PrimAITE Log Level: {level}") @@ -170,16 +175,50 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): ldc: The lay down config file path. Optional. If no value is passed then example default lay down config is used from: - ~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml. + ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. """ - from primaite.main import run + from primaite.config.lay_down_config import dos_very_basic_config_path from primaite.config.training_config import main_training_config_path - from primaite.config.lay_down_config import data_manipulation_config_path + from primaite.main import run if not tc: tc = main_training_config_path() if not ldc: - ldc = data_manipulation_config_path() + ldc = dos_very_basic_config_path() run(training_config_path=tc, lay_down_config_path=ldc) + + +@app.command() +def plotly_template( + template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None +): + """ + View or set the plotly template for Session plots. + + To View, simply call: primaite plotly-template + + To set, call: primaite plotly-template + + For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK + """ + app_dirs = PlatformDirs(appname="primaite") + app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) + user_config_path = app_dirs.user_config_path / "primaite_config.yaml" + if user_config_path.exists(): + with open(user_config_path, "r") as file: + primaite_config = yaml.safe_load(file) + + if template: + primaite_config["session"]["outputs"]["plots"][ + "template" + ] = template.value + with open(user_config_path, "w") as file: + yaml.dump(primaite_config, file) + print(f"PrimAITE plotly template: {template.value}") + else: + template = primaite_config["session"]["outputs"]["plots"][ + "template" + ] + print(f"PrimAITE plotly template: {template}") diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 6a93e1b5..a363a1a0 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -83,6 +83,7 @@ class Protocol(Enum): class SessionType(Enum): """The type of PrimAITE Session to be run.""" + TRAIN = 1 "Train an agent" EVAL = 2 @@ -93,6 +94,7 @@ class SessionType(Enum): class VerboseLevel(IntEnum): """PrimAITE Session Output verbose level.""" + NO_OUTPUT = 0 INFO = 1 DEBUG = 2 @@ -100,6 +102,7 @@ class VerboseLevel(IntEnum): class AgentFramework(Enum): """The agent algorithm framework/package.""" + CUSTOM = 0 "Custom Agent" SB3 = 1 @@ -110,6 +113,7 @@ class AgentFramework(Enum): class DeepLearningFramework(Enum): """The deep learning framework.""" + TF = "tf" "Tensorflow" TF2 = "tf2" @@ -120,6 +124,7 @@ class DeepLearningFramework(Enum): class AgentIdentifier(Enum): """The Red Agent algo/class.""" + A2C = 1 "Advantage Actor Critic" PPO = 2 @@ -136,6 +141,7 @@ class AgentIdentifier(Enum): class HardCodedAgentView(Enum): """The view the deterministic hard-coded agent has of the environment.""" + BASIC = 1 "The current observation space only" FULL = 2 @@ -144,6 +150,7 @@ class HardCodedAgentView(Enum): class ActionType(Enum): """Action type enumeration.""" + NODE = 0 ACL = 1 ANY = 2 @@ -151,6 +158,7 @@ class ActionType(Enum): class ObservationType(Enum): """Observation type enumeration.""" + BOX = 0 MULTIDISCRETE = 1 @@ -193,6 +201,7 @@ class LinkStatus(Enum): class OutputVerboseLevel(IntEnum): """The Agent output verbosity level.""" + NONE = 0 "No Output" INFO = 1 diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 3cccbcae..cc5d4955 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -35,10 +35,10 @@ hard_coded_agent_view: FULL # "NODE" # "ACL" # "ANY" node and acl actions -action_type: NODE +action_type: ANY # Number of episodes to run per session -num_episodes: 1000 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 @@ -47,14 +47,14 @@ num_steps: 256 # Set to 0 if no checkpoints are required. Default is 10 checkpoint_every_n_episodes: 10 -# Time delay between steps (for generic agents) +# Time delay (milliseconds) between steps for CUSTOM agents. time_delay: 5 # Type of session to be run. Options are: # "TRAIN" (Trains an agent) # "EVAL" (Evaluates an agent) # "TRAIN_EVAL" (Trains then evaluates an agent) -session_type: TRAIN +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 49a33d6e..ae067228 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,20 +1,20 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from pathlib import Path -from typing import Final, Union, Dict, Any +from typing import Any, Dict, Final, Union -import networkx import yaml -from primaite import USERS_CONFIG_DIR, getLogger +from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[ - Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +_EXAMPLE_LAY_DOWN: Final[Path] = ( + USERS_CONFIG_DIR / "example_config" / "lay_down" +) def convert_legacy_lay_down_config_dict( - legacy_config_dict: Dict[str, Any] + legacy_config_dict: Dict[str, Any] ) -> Dict[str, Any]: """ Convert a legacy lay down config dict to the new format. @@ -25,10 +25,7 @@ def convert_legacy_lay_down_config_dict( return legacy_config_dict -def load( - file_path: Union[str, Path], - legacy_file: bool = False -) -> Dict: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: """ Read in a lay down config yaml file. diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 72b5523a..84dd3cc8 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -7,15 +7,22 @@ from typing import Any, Dict, Final, Optional, Union import yaml -from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import DeepLearningFramework, HardCodedAgentView -from primaite.common.enums import ActionType, AgentIdentifier, \ - AgentFramework, SessionType, OutputVerboseLevel +from primaite import getLogger, USERS_CONFIG_DIR +from primaite.common.enums import ( + ActionType, + AgentFramework, + AgentIdentifier, + DeepLearningFramework, + HardCodedAgentView, + OutputVerboseLevel, + SessionType, +) _LOGGER = getLogger(__name__) -_EXAMPLE_TRAINING: Final[ - Path] = USERS_CONFIG_DIR / "example_config" / "training" +_EXAMPLE_TRAINING: Final[Path] = ( + USERS_CONFIG_DIR / "example_config" / "training" +) def main_training_config_path() -> Path: @@ -36,6 +43,7 @@ def main_training_config_path() -> Path: @dataclass() class TrainingConfig: """The Training Config class.""" + agent_framework: AgentFramework = AgentFramework.SB3 "The AgentFramework" @@ -171,12 +179,16 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system" - @classmethod def from_dict( - cls, - config_dict: Dict[str, Union[str, int, bool]] + cls, config_dict: Dict[str, Union[str, int, bool]] ) -> TrainingConfig: + """ + Create an instance of TrainingConfig from a dict. + + :param config_dict: The training config dict. + :return: The instance of TrainingConfig. + """ field_enum_map = { "agent_framework": AgentFramework, "deep_learning_framework": DeepLearningFramework, @@ -187,9 +199,9 @@ class TrainingConfig: "hard_coded_agent_view": HardCodedAgentView, } - for field, enum_class in field_enum_map.items(): - if field in config_dict: - config_dict[field] = enum_class[config_dict[field]] + for key, value in field_enum_map.items(): + if key in config_dict: + config_dict[key] = value[config_dict[key]] return TrainingConfig(**config_dict) def to_dict(self, json_serializable: bool = True): @@ -213,23 +225,21 @@ class TrainingConfig: return data def __str__(self) -> str: - tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, " + tc = f"{self.agent_framework}, " if self.agent_framework is AgentFramework.RLLIB: - tc += f"deep_learning_framework=" \ - f"{self.deep_learning_framework.name}, " - tc += f"agent_identifier={self.agent_identifier.name}, " + tc += f"{self.deep_learning_framework}, " + tc += f"{self.agent_identifier}, " if self.agent_identifier is AgentIdentifier.HARDCODED: - tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, " - tc += f"action_type={self.action_type.name}, " + tc += f"{self.hard_coded_agent_view}, " + tc += f"{self.action_type}, " tc += f"observation_space={self.observation_space}, " - tc += f"num_episodes={self.num_episodes}, " - tc += f"num_steps={self.num_steps})" + tc += f"{self.num_episodes} episodes @ " + tc += f"{self.num_steps} steps" return tc def load( - file_path: Union[str, Path], - legacy_file: bool = False + file_path: Union[str, Path], legacy_file: bool = False ) -> TrainingConfig: """ Read in a training config yaml file. @@ -273,12 +283,12 @@ def load( def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - agent_framework: AgentFramework = AgentFramework.SB3, - agent_identifier: AgentIdentifier = AgentIdentifier.PPO, - action_type: ActionType = ActionType.ANY, - num_steps: int = 256, - output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO + legacy_config_dict: Dict[str, Any], + agent_framework: AgentFramework = AgentFramework.SB3, + agent_identifier: AgentIdentifier = AgentIdentifier.PPO, + action_type: ActionType = ActionType.ANY, + num_steps: int = 256, + output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO, ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -301,8 +311,12 @@ def convert_legacy_training_config_dict( "agent_identifier": agent_identifier.name, "action_type": action_type.name, "num_steps": num_steps, - "output_verbose_level": output_verbose_level + "output_verbose_level": output_verbose_level.name, } + session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} + legacy_config_dict["sessionType"] = session_type_map[ + legacy_config_dict["sessionType"] + ] for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py new file mode 100644 index 00000000..a7cc3e8b --- /dev/null +++ b/src/primaite/data_viz/__init__.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class PlotlyTemplate(Enum): + """The built-in plotly templates.""" + + PLOTLY = "plotly" + PLOTLY_WHITE = "plotly_white" + PLOTLY_DARK = "plotly_dark" + GGPLOT2 = "ggplot2" + SEABORN = "seaborn" + SIMPLE_WHITE = "simple_white" + NONE = "none" diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py new file mode 100644 index 00000000..245b9774 --- /dev/null +++ b/src/primaite/data_viz/session_plots.py @@ -0,0 +1,73 @@ +from pathlib import Path +from typing import Dict, Optional, Union + +import plotly.graph_objects as go +import polars as pl +import yaml +from plotly.graph_objs import Figure + +from primaite import _PLATFORM_DIRS + + +def _get_plotly_config() -> Dict: + """Get the plotly config from primaite_config.yaml.""" + user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" + with open(user_config_path, "r") as file: + primaite_config = yaml.safe_load(file) + return primaite_config["session"]["outputs"]["plots"] + + +def plot_av_reward_per_episode( + av_reward_per_episode_csv: Union[str, Path], + title: Optional[str] = None, + subtitle: Optional[str] = None, +) -> Figure: + """ + Plot the average reward per episode from a csv session output. + + :param av_reward_per_episode_csv: The average reward per episode csv + file path. + :param title: The plot title. This is optional. + :param subtitle: The plot subtitle. This is optional. + :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``. + """ + df = pl.read_csv(av_reward_per_episode_csv) + + if title: + if subtitle: + title = f"{title}
{subtitle}" + else: + if subtitle: + title = subtitle + + config = _get_plotly_config() + layout = go.Layout( + autosize=config["size"]["auto_size"], + width=config["size"]["width"], + height=config["size"]["height"], + ) + # Create the line graph with a colored line + fig = go.Figure(layout=layout) + fig.update_layout(template=config["template"]) + fig.add_trace( + go.Scatter( + x=df["Episode"], + y=df["Average Reward"], + mode="lines", + name="Mean Reward per Episode", + ) + ) + + # Set the layout of the graph + fig.update_layout( + xaxis={ + "title": "Episode", + "type": "linear", + "rangeslider": {"visible": config["range_slider"]}, + }, + yaxis={"title": "Average Reward"}, + title=title, + showlegend=False, + ) + + return fig diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..6893125e 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union +from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union import numpy as np from gym import spaces @@ -77,7 +77,9 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.current_observation = np.zeros( + observation_shape, dtype=self._DATA_TYPE + ) def update(self): """Update the observation based on current environment state. @@ -92,7 +94,9 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation[item_index][0] = int(node.node_id) self.current_observation[item_index][1] = node.hardware_state.value if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][ + 2 + ] = node.software_state.value self.current_observation[item_index][ 3 ] = node.file_system_state_observed.value @@ -199,9 +203,16 @@ class NodeStatuses(AbstractObservationComponent): if isinstance(node, ServiceNode): for i, service in enumerate(self.env.services_list): if node.has_service(service): - service_states[i] = node.get_service_state(service).value + service_states[i] = node.get_service_state( + service + ).value obs.extend( - [hardware_state, software_state, file_system_state, *service_states] + [ + hardware_state, + software_state, + file_system_state, + *service_states, + ] ) self.current_observation[:] = obs @@ -259,7 +270,9 @@ class LinkTrafficLevels(AbstractObservationComponent): # 1. Define the shape of your observation space component shape = ( - [self._quantisation_levels] * self.env.num_links * self._entries_per_link + [self._quantisation_levels] + * self.env.num_links + * self._entries_per_link ) # 2. Create Observation space @@ -279,7 +292,9 @@ class LinkTrafficLevels(AbstractObservationComponent): if self._combine_service_traffic: loads = [link.get_current_load()] else: - loads = [protocol.get_load() for protocol in link.protocol_list] + loads = [ + protocol.get_load() for protocol in link.protocol_list + ] for load in loads: if load <= 0: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e43dc8a5..ea8f82d4 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -2,7 +2,7 @@ """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy from pathlib import Path -from typing import Dict, Tuple, Union, Final +from typing import Dict, Final, Tuple, Union import networkx as nx import numpy as np @@ -12,8 +12,10 @@ from matplotlib import pyplot as plt from primaite import getLogger from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import ( + is_valid_acl_action_extra, + is_valid_node_action, +) from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,7 +26,8 @@ from primaite.common.enums import ( NodeType, ObservationType, Priority, - SoftwareState, SessionType, + SessionType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -34,15 +37,18 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import ( + NodeStateInstructionGreen, +) from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import ( + apply_red_agent_iers, + apply_red_agent_node_pol, +) from primaite.transactions.transaction import Transaction from primaite.utils.session_output_writer import SessionOutputWriter @@ -59,11 +65,11 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -237,27 +243,19 @@ class Primaite(Env): ) self.episode_av_reward_writer = SessionOutputWriter( - self, - transaction_writer=False, - learning_session=True + self, transaction_writer=False, learning_session=True ) self.transaction_writer = SessionOutputWriter( - self, - transaction_writer=True, - learning_session=True + self, transaction_writer=True, learning_session=True ) def set_as_eval(self): """Set the writers to write to eval directories.""" self.episode_av_reward_writer = SessionOutputWriter( - self, - transaction_writer=False, - learning_session=False + self, transaction_writer=False, learning_session=False ) self.transaction_writer = SessionOutputWriter( - self, - transaction_writer=True, - learning_session=False + self, transaction_writer=True, learning_session=False ) self.episode_count = 0 self.step_count = 0 @@ -322,9 +320,7 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction( - self.agent_identifier, - self.episode_count, - self.step_count + self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction transaction.obs_space_pre = copy.deepcopy(self.env_obs) @@ -354,8 +350,9 @@ class Primaite(Env): self.nodes_post_pol = copy.deepcopy(self.nodes) self.links_post_pol = copy.deepcopy(self.links) # Reference - apply_node_pol(self.nodes_reference, self.node_pol, - self.step_count) # Node PoL + apply_node_pol( + self.nodes_reference, self.node_pol, self.step_count + ) # Node PoL apply_iers( self.network_reference, self.nodes_reference, @@ -404,8 +401,10 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - _LOGGER.info(f"Episode: {self.episode_count}, " - f"Average Reward: {self.average_reward}") + _LOGGER.info( + f"Episode: {self.episode_count}, " + f"Average Reward: {self.average_reward}" + ) # Load the reward into the transaction transaction.reward = reward @@ -452,11 +451,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -525,7 +524,7 @@ class Primaite(Env): # Patch (valid action if it's good or compromised) node.set_service_state( self.services_list[service_index], - SoftwareState.PATCHING + SoftwareState.PATCHING, ) else: # Node is not of Service Type @@ -542,7 +541,10 @@ class Primaite(Env): elif property_action == 2: # Repair # You cannot repair a destroyed file system - it needs restoring - if node.file_system_state_actual != FileSystemState.DESTROYED: + if ( + node.file_system_state_actual + != FileSystemState.DESTROYED + ): node.set_file_system_state(FileSystemState.REPAIRING) elif property_action == 3: # Restore @@ -585,8 +587,9 @@ class Primaite(Env): acl_rule_source = "ANY" else: node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, - ActiveNode): + if isinstance(node, ServiceNode) or isinstance( + node, ActiveNode + ): acl_rule_source = node.ip_address else: return @@ -595,8 +598,9 @@ class Primaite(Env): acl_rule_destination = "ANY" else: node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, - ActiveNode): + if isinstance(node, ServiceNode) or isinstance( + node, ActiveNode + ): acl_rule_destination = node.ip_address else: return @@ -681,8 +685,9 @@ class Primaite(Env): :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] """ - self.obs_handler = ObservationsHandler.from_config(self, - self.obs_config) + self.obs_handler = ObservationsHandler.from_config( + self, self.obs_config + ) return self.obs_handler.space, self.obs_handler.current_observation @@ -790,7 +795,8 @@ class Primaite(Env): service_port = service["port"] service_state = SoftwareState[service["state"]] node.add_service( - Service(service_protocol, service_port, service_state)) + Service(service_protocol, service_port, service_state) + ) else: # Bad formatting pass @@ -843,8 +849,9 @@ class Primaite(Env): dest_node_ref: Node = self.nodes_reference[link_destination] # Add link to network (reference) - self.network_reference.add_edge(source_node_ref, dest_node_ref, - id=link_name) + self.network_reference.add_edge( + source_node_ref, dest_node_ref, id=link_name + ) # Add link to link dictionary (reference) self.links_reference[link_name] = Link( @@ -1120,7 +1127,8 @@ class Primaite(Env): node_id = item["node_id"] node_class = item["node_class"] node_hardware_state: HardwareState = HardwareState[ - item["hardware_state"]] + item["hardware_state"] + ] node: NodeUnion = self.nodes[node_id] node_ref = self.nodes_reference[node_id] @@ -1186,8 +1194,12 @@ class Primaite(Env): # Use MAX to ensure we get them all for node_action in range(4): for service_state in range(self.num_services): - action = [node, node_property, node_action, - service_state] + action = [ + node, + node_property, + node_action, + service_state, + ] # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): actions[action_key] = action diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 00e45fa3..4dd0550e 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -46,7 +46,9 @@ def calculate_reward_function( ) # Software State - if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode): + if isinstance(final_node, ActiveNode) or isinstance( + final_node, ServiceNode + ): reward_value += score_node_os_state( final_node, initial_node, reference_node, config_values ) @@ -81,7 +83,8 @@ def calculate_reward_function( reference_blocked = not reference_ier.get_is_running() live_blocked = not ier_value.get_is_running() ier_reward = ( - config_values.green_ier_blocked * ier_value.get_mission_criticality() + config_values.green_ier_blocked + * ier_value.get_mission_criticality() ) if live_blocked and not reference_blocked: @@ -104,7 +107,9 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values): +def score_node_operating_state( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the hardware state of a node. @@ -153,7 +158,9 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values): +def score_node_os_state( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the Software State of a node. @@ -204,7 +211,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values): +def score_node_service_state( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the service state(s) of a node. @@ -276,7 +285,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values): +def score_node_file_system( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the file system state of a node. diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 90235e9f..054f4c34 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -8,7 +8,9 @@ from primaite.common.protocol import Protocol class Link(object): """Link class.""" - def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): + def __init__( + self, _id, _bandwidth, _source_node_name, _dest_node_name, _services + ): """ Init. diff --git a/src/primaite/main.py b/src/primaite/main.py index 3c0f93b3..556c5ec3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -10,7 +10,10 @@ from primaite.primaite_session import PrimaiteSession _LOGGER = getLogger(__name__) -def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): +def run( + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], +): """Run the PrimAITE Session. :param training_config_path: The training config filepath. diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index 57fa4c68..b1c3f57c 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -87,7 +87,9 @@ class ActiveNode(Node): f"Node.software_state:{self._software_state}" ) - def set_software_state_if_not_compromised(self, software_state: SoftwareState): + def set_software_state_if_not_compromised( + self, software_state: SoftwareState + ): """ Sets Software State if the node is not compromised. @@ -98,7 +100,9 @@ class ActiveNode(Node): if self._software_state != SoftwareState.COMPROMISED: self._software_state = software_state if software_state == SoftwareState.PATCHING: - self.patching_count = self.config_values.os_patching_duration + self.patching_count = ( + self.config_values.os_patching_duration + ) else: _LOGGER.info( f"The Nodes hardware state is OFF so OS State cannot be changed." @@ -187,7 +191,9 @@ class ActiveNode(Node): def start_file_system_scan(self): """Starts a file system scan.""" self.file_system_scanning = True - self.file_system_scanning_count = self.config_values.file_system_scanning_limit + self.file_system_scanning_count = ( + self.config_values.file_system_scanning_limit + ) def update_file_system_state(self): """Updates file system status based on scanning/restore/repair cycle.""" @@ -206,7 +212,10 @@ class ActiveNode(Node): self.file_system_state_observed = FileSystemState.GOOD # Scanning updates - if self.file_system_scanning == True and self.file_system_scanning_count < 0: + if ( + self.file_system_scanning == True + and self.file_system_scanning_count < 0 + ): self.file_system_state_observed = self.file_system_state_actual self.file_system_scanning = False self.file_system_scanning_count = 0 diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 2b1d94be..04681807 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -32,7 +32,9 @@ class NodeStateInstructionGreen(object): self.end_step = _end_step self.node_id = _node_id self.node_pol_type = _node_pol_type - self.service_name = _service_name # Not used when not a service instruction + self.service_name = ( + _service_name # Not used when not a service instruction + ) self.state = _state def get_start_step(self): diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..ba35067c 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -42,7 +42,9 @@ class NodeStateInstructionRed(object): self.target_node_id = _target_node_id self.initiator = _pol_initiator self.pol_type: NodePOLType = _pol_type - self.service_name = pol_protocol # Not used when not a service instruction + self.service_name = ( + pol_protocol # Not used when not a service instruction + ) self.state = _pol_state self.source_node_id = _pol_source_node_id self.source_node_service = _pol_source_node_service diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 324592c3..6dcff73e 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -110,7 +110,9 @@ class ServiceNode(ActiveNode): return False return False - def set_service_state(self, protocol_name: str, software_state: SoftwareState): + def set_service_state( + self, protocol_name: str, software_state: SoftwareState + ): """ Sets the software_state of a service (protocol) on the node. diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 71ed343e..0e81e581 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -4,7 +4,7 @@ import os import subprocess import sys -from primaite import NOTEBOOKS_DIR, getLogger +from primaite import getLogger, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 1d05dc3f..aeae7add 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -6,10 +6,17 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState +from primaite.common.enums import ( + HardwareState, + NodePOLType, + NodeType, + SoftwareState, +) from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import ( + NodeStateInstructionGreen, +) from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER @@ -190,7 +197,9 @@ def apply_iers( link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if (link.get_current_load() + load) > link.get_bandwidth(): + if ( + link.get_current_load() + load + ) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -204,7 +213,8 @@ def apply_iers( while count < path_node_list_length - 1: # Get the link between the next two nodes edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] + path_node_list[count], + path_node_list[count + 1], ) link_id = edge_dict[0].get("id") link = links[link_id] @@ -216,7 +226,9 @@ def apply_iers( else: # One of the nodes is not operational if _VERBOSE: - print("Path not valid - one or more nodes not operational") + print( + "Path not valid - one or more nodes not operational" + ) pass else: @@ -231,7 +243,9 @@ def apply_iers( def apply_node_pol( nodes: Dict[str, NodeUnion], - node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], + node_pol: Dict[ + any, Union[NodeStateInstructionGreen, NodeStateInstructionRed] + ], step: int, ): """ @@ -263,16 +277,22 @@ def apply_node_pol( elif node_pol_type == NodePOLType.OS: # Change OS state # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + if isinstance(node, ActiveNode) or isinstance( + node, ServiceNode + ): node.set_software_state_if_not_compromised(state) elif node_pol_type == NodePOLType.SERVICE: # Change a service state # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this if isinstance(node, ServiceNode): - node.set_service_state_if_not_compromised(service_name, state) + node.set_service_state_if_not_compromised( + service_name, state + ) else: # Change the file system status - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + if isinstance(node, ActiveNode) or isinstance( + node, ServiceNode + ): node.set_file_system_state_if_not_compromised(state) else: # PoL is not valid in this time step diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index b23992e7..96fe787c 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -176,7 +176,9 @@ def apply_red_agent_iers( link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if (link.get_current_load() + load) > link.get_bandwidth(): + if ( + link.get_current_load() + load + ) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -190,7 +192,8 @@ def apply_red_agent_iers( while count < path_node_list_length - 1: # Get the link between the next two nodes edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] + path_node_list[count], + path_node_list[count + 1], ) link_id = edge_dict[0].get("id") link = links[link_id] @@ -200,16 +203,23 @@ def apply_red_agent_iers( # This IER is now valid, so set it to running ier_value.set_is_running(True) if _VERBOSE: - print("Red IER was allowed to run in step " + str(step)) + print( + "Red IER was allowed to run in step " + + str(step) + ) else: # One of the nodes is not operational if _VERBOSE: - print("Path not valid - one or more nodes not operational") + print( + "Path not valid - one or more nodes not operational" + ) pass else: if _VERBOSE: - print("Red IER was NOT allowed to run in step " + str(step)) + print( + "Red IER was NOT allowed to run in step " + str(step) + ) print("Source, Dest or ACL were not valid") pass # ------------------------------------ @@ -264,7 +274,9 @@ def apply_red_agent_node_pol( passed_checks = True elif initiator == NodePOLInitiator.IER: # Need to check there is a red IER incoming - passed_checks = is_red_ier_incoming(target_node, iers, pol_type) + passed_checks = is_red_ier_incoming( + target_node, iers, pol_type + ) elif initiator == NodePOLInitiator.SERVICE: # Need to check the condition of a service on another node source_node = nodes[source_node_id] @@ -308,7 +320,9 @@ def apply_red_agent_node_pol( target_node.set_file_system_state(state) else: if _VERBOSE: - print("Node Red Agent PoL not allowed - did not pass checks") + print( + "Node Red Agent PoL not allowed - did not pass checks" + ) else: # PoL is not valid in this time step pass @@ -323,7 +337,10 @@ def is_red_ier_incoming(node, iers, node_pol_type): node_id = node.node_id for ier_key, ier_value in iers.items(): - if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id: + if ( + ier_value.get_is_running() + and ier_value.get_dest_node_id() == node_id + ): if ( node_pol_type == NodePOLType.OPERATING or node_pol_type == NodePOLType.OS diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index cd959be0..4d8d3022 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -1,54 +1,51 @@ from __future__ import annotations -import json -from datetime import datetime from pathlib import Path -from typing import Final, Optional, Union, Dict -from uuid import uuid4 +from typing import Dict, Final, Optional, Union -from primaite import getLogger, SESSIONS_DIR +from primaite import getLogger from primaite.agents.agent import AgentSessionABC from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent -from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \ - RandomAgent, DummyAgent -from primaite.common.enums import AgentFramework, AgentIdentifier, \ - ActionType, SessionType +from primaite.agents.simple import ( + DoNothingACLAgent, + DoNothingNodeAgent, + DummyAgent, + RandomAgent, +) +from primaite.common.enums import ( + ActionType, + AgentFramework, + AgentIdentifier, + SessionType, +) from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) -def _get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. - - This is set in the format of: - ~/primaite/sessions//_. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - _LOGGER.debug(f"Created PrimAITE Session path: {session_path}") - - return session_path - - class PrimaiteSession: + """ + The PrimaiteSession class. + + Provides a single learning and evaluation entry point for all training + and lay down configurations. + """ def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path] + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], ): + """ + The PrimaiteSession constructor. + + :param training_config_path: The training config path. + :param lay_down_config_path: The lay down config path. + """ if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path @@ -64,22 +61,35 @@ class PrimaiteSession: ) self._agent_session: AgentSessionABC = None # noqa + self.session_path: Path = None # noqa + self.timestamp_str: str = None # noqa + self.learning_path: Path = None # noqa + self.evaluation_path: Path = None # noqa def setup(self): + """Performs the session setup.""" if self._training_config.agent_framework == AgentFramework.CUSTOM: - if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}" + ) + if ( + self._training_config.agent_identifier + == AgentIdentifier.HARDCODED + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.HARDCODED}" + ) if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space self._agent_session = HardCodedNodeAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space self._agent_session = HardCodedACLAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ANY: @@ -90,18 +100,23 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + elif ( + self._training_config.agent_identifier + == AgentIdentifier.DO_NOTHING + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.DO_NOTHINGD}" + ) if self._training_config.action_type == ActionType.NODE: self._agent_session = DoNothingNodeAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space self._agent_session = DoNothingACLAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ANY: @@ -112,15 +127,26 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: - self._agent_session = RandomAgent( - self._training_config_path, - self._lay_down_config_path + elif ( + self._training_config.agent_identifier + == AgentIdentifier.RANDOM + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.RANDOM}" + ) + self._agent_session = RandomAgent( + self._training_config_path, self._lay_down_config_path + ) + elif ( + self._training_config.agent_identifier == AgentIdentifier.DUMMY + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.DUMMY}" ) - elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: self._agent_session = DummyAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) else: @@ -128,37 +154,64 @@ class PrimaiteSession: raise ValueError elif self._training_config.agent_framework == AgentFramework.SB3: + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}" + ) # Stable Baselines3 Agent self._agent_session = SB3Agent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.agent_framework == AgentFramework.RLLIB: + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}" + ) # Ray RLlib Agent self._agent_session = RLlibAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) else: # Invalid AgentFramework raise ValueError + self.session_path: Path = self._agent_session.session_path + self.timestamp_str: str = self._agent_session.timestamp_str + self.learning_path: Path = self._agent_session.learning_path + self.evaluation_path: Path = self._agent_session.evaluation_path + def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of time steps per episode. + :param episodes: The number of episodes. + :param kwargs: Any agent-framework specific key word args. + """ if not self._training_config.session_type == SessionType.EVAL: self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of time steps per episode. + :param episodes: The number of episodes. + :param kwargs: Any agent-framework specific key word args. + """ if not self._training_config.session_type == SessionType.TRAIN: self._agent_session.evaluate(time_steps, episodes, **kwargs) + + def close(self): + """Closes the agent.""" + self._agent_session.close() diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 1dd8775b..b9e0d73c 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -9,3 +9,14 @@ logging: WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + +# Session +session: + outputs: + plots: + size: + auto_size: false + width: 1500 + height: 900 + template: plotly_white + range_slider: false diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 5192c48f..59eaf8cc 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -6,7 +6,7 @@ from pathlib import Path import pkg_resources -from primaite import NOTEBOOKS_DIR, getLogger +from primaite import getLogger, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) @@ -24,7 +24,9 @@ def run(overwrite_existing: bool = True): for subdir, dirs, files in os.walk(notebooks_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep) + path_split = os.path.relpath( + fp, notebooks_package_data_root + ).split(os.sep) target_fp = NOTEBOOKS_DIR / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index f4166c6a..f2b4a18f 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -5,7 +5,7 @@ from pathlib import Path import pkg_resources -from primaite import USERS_CONFIG_DIR, getLogger +from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) @@ -24,7 +24,9 @@ def run(overwrite_existing=True): for subdir, dirs, files in os.walk(configs_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep) + path_split = os.path.relpath(fp, configs_package_data_root).split( + os.sep + ) target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index 9f6e8a13..693b11c1 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -1,5 +1,5 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger +from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 6e5ba5f0..1a71f0ff 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -7,12 +7,7 @@ from typing import List, Tuple class Transaction(object): """Transaction class.""" - def __init__( - self, - agent_identifier, - episode_number, - step_number - ): + def __init__(self, agent_identifier, episode_number, step_number): """ Transaction constructor. @@ -37,6 +32,11 @@ class Transaction(object): "The action space invoked by the agent" def as_csv_data(self) -> Tuple[List, List]: + """ + Converts the Transaction to a csv data row and provides a header. + + :return: A tuple consisting of (header, data). + """ if isinstance(self.action_space, int): action_length = self.action_space else: @@ -74,12 +74,14 @@ class Transaction(object): str(self.reward), ] row = ( - row - + _turn_action_space_to_array(self.action_space) - + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, - obs_features) - + _turn_obs_space_to_array(self.obs_space_post, obs_assets, - obs_features) + row + + _turn_action_space_to_array(self.action_space) + + _turn_obs_space_to_array( + self.obs_space_pre, obs_assets, obs_features + ) + + _turn_obs_space_to_array( + self.obs_space_post, obs_assets, obs_features + ) ) return header, row diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py new file mode 100644 index 00000000..d04f375e --- /dev/null +++ b/src/primaite/utils/session_output_reader.py @@ -0,0 +1,20 @@ +from pathlib import Path +from typing import Dict, Union + +# Using polars as it's faster than Pandas; it will speed things up when +# files get big! +import polars as pl + + +def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: + """ + Read an average rewards per episode csv file and return as a dict. + + The dictionary keys are the episode number, and the values are the mean + reward that episode. + + :param av_rewards_csv_file: The average rewards per episode csv file path. + :return: The average rewards per episode cdv as a dict. + """ + d = pl.read_csv(av_rewards_csv_file).to_dict() + return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])} diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 308e1fb3..86c5ca28 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -1,7 +1,6 @@ import csv from logging import Logger -from typing import List, Final, IO, Union, Tuple -from typing import TYPE_CHECKING +from typing import Final, List, Tuple, TYPE_CHECKING, Union from primaite import getLogger from primaite.transactions.transaction import Transaction @@ -13,15 +12,22 @@ _LOGGER: Logger = getLogger(__name__) class SessionOutputWriter: + """ + A session output writer class. + + Is used to write session outputs to csv file. + """ + _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [ - "Episode", "Average Reward" + "Episode", + "Average Reward", ] def __init__( - self, - env: "Primaite", - transaction_writer: bool = False, - learning_session: bool = True + self, + env: "Primaite", + transaction_writer: bool = False, + learning_session: bool = True, ): self._env = env self.transaction_writer = transaction_writer @@ -52,14 +58,21 @@ class SessionOutputWriter: self._csv_writer = csv.writer(self._csv_file) def __del__(self): + self.close() + + def close(self): + """Close the cvs file.""" if self._csv_file: self._csv_file.close() - _LOGGER.info(f"Finished writing file: {self._csv_file_path}") + _LOGGER.debug(f"Finished writing file: {self._csv_file_path}") - def write( - self, - data: Union[Tuple, Transaction] - ): + def write(self, data: Union[Tuple, Transaction]): + """ + Write a row of session data. + + :param data: The row of data to write. Can be a Tuple or an instance + of Transaction. + """ if isinstance(data, Transaction): header, data = data.as_csv_data() else: @@ -69,5 +82,4 @@ class SessionOutputWriter: self._init_csv_writer() self._csv_writer.writerow(header) self._first_write = False - self._csv_writer.writerow(data) diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml index 9fdf9a05..49e6a00b 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy/new_training_config.yaml @@ -6,7 +6,7 @@ # "SB3" (Stable Baselines3) # "RLLIB" (Ray[RLlib]) # "NONE" (Custom Agent) -agent_framework: RLLIB +agent_framework: SB3 # Sets which Red Agent algo/class will be used: # "PPO" (Proximal Policy Optimization) @@ -27,7 +27,7 @@ num_steps: 256 # Time delay between steps (for generic agents) time_delay: 10 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml index 67aaa9de..d26d7955 100644 --- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: A2C -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -28,7 +39,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml index 29a89b8d..aae740b6 100644 --- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -24,7 +35,7 @@ observation_space: time_delay: 1 # Filename of the scenario / laydown -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml index 8f2d9a38..4066eace 100644 --- a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -25,7 +36,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml index e8bb49ea..08452dda 100644 --- a/tests/config/obs_tests/main_config_without_obs.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 5 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index 2e752bc9..7f1ced01 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -1,10 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: DUMMY -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +30,7 @@ num_steps: 15 time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 5c5db582..97d0ddaf 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 15 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index 967fdcce..067b9a6d 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 5 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/conftest.py b/tests/conftest.py index 945d23f0..41dc5e77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,37 +1,151 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import datetime +import shutil import tempfile import time from datetime import datetime from pathlib import Path -from typing import Union +from typing import Dict, Union +from unittest.mock import patch +import pytest + +from primaite import getLogger +from primaite.common.enums import AgentIdentifier from primaite.environment.primaite_env import Primaite +from primaite.primaite_session import PrimaiteSession +from primaite.utils.session_output_reader import av_rewards_dict +from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 +_LOGGER = getLogger(__name__) -def _get_temp_session_path(session_timestamp: datetime) -> Path: + +class TempPrimaiteSession(PrimaiteSession): + """ + A temporary PrimaiteSession class. + + Uses context manager for deletion of files upon exit. + """ + + def __init__( + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + ): + super().__init__(training_config_path, lay_down_config_path) + self.setup() + + def learn_av_reward_per_episode(self) -> Dict[int, float]: + """Get the learn av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.learning_path / csv_file) + + def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: + """Get the eval av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.evaluation_path / csv_file) + + @property + def env(self) -> Primaite: + """Direct access to the env for ease of testing.""" + return self._agent_session._env # noqa + + def __enter__(self): + return self + + def __exit__(self, type, value, tb): + del self._agent_session._env.episode_av_reward_writer + del self._agent_session._env.transaction_writer + shutil.rmtree(self.session_path) + shutil.rmtree(self.session_path.parent) + _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") + + +@pytest.fixture +def temp_primaite_session(request): + """ + Provides a temporary PrimaiteSession instance. + + It's temporary as it uses a temporary directory as the session path. + + To use this fixture you need to: + + - parametrize your test function with: + + - "temp_primaite_session" + - [[path to training config, path to lay down config]] + - Include the temp_primaite_session fixture as a param in your test + function. + - use the temp_primaite_session as a context manager assigning is the + name 'session'. + + .. code:: python + + from primaite.config.lay_down_config import dos_very_basic_config_path + from primaite.config.training_config import main_training_config_path + @pytest.mark.parametrize( + "temp_primaite_session", + [ + [main_training_config_path(), dos_very_basic_config_path()] + ], + indirect=True + ) + def test_primaite_session(temp_primaite_session): + with temp_primaite_session as session: + # Learning outputs are saved in session.learning_path + session.learn() + + # Evaluation outputs are saved in session.evaluation_path + session.evaluate() + + # To ensure that all files are written, you must call .close() + session.close() + + # If you need to inspect any session outputs, it must be done + # inside the context manager + + # Now that we've exited the context manager, the + # session.session_path directory and its contents are deleted + """ + training_config_path = request.param[0] + lay_down_config_path = request.param[1] + with patch( + "primaite.agents.agent.get_session_path", get_temp_session_path + ) as mck: + mck.session_timestamp = datetime.now() + + return TempPrimaiteSession(training_config_path, lay_down_config_path) + + +@pytest.fixture +def temp_session_path() -> Path: """ Get a temp directory session path the test session will output to. - :param session_timestamp: This is the datetime that the session started. :return: The session directory path. """ + session_timestamp = datetime.now() date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + session_path = ( + Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + ) session_path.mkdir(exist_ok=True, parents=True) return session_path def _get_primaite_env_from_config( - training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + temp_session_path, ): """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() - session_path = _get_temp_session_path(session_timestamp) + session_path = temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( @@ -45,7 +159,7 @@ def _get_primaite_env_from_config( # TOOD: This needs t be refactored to happen outside. Should be part of # a main Session class. - if env.training_config.agent_identifier == "GENERIC": + if env.training_config.agent_identifier is AgentIdentifier.RANDOM: run_generic(env, config_values) return env diff --git a/tests/e2e_integration_tests/test_primaite_main.py b/tests/e2e_integration_tests/test_primaite_main.py deleted file mode 100644 index b457557a..00000000 --- a/tests/e2e_integration_tests/test_primaite_main.py +++ /dev/null @@ -1,8 +0,0 @@ -from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.config.training_config import main_training_config_path -from primaite.main import run - - -def test_primaite_main_e2e(): - """Tests the primaite.main.run function end-to-end.""" - run(main_training_config_path(), data_manipulation_config_path()) diff --git a/tests/e2e_integration_tests/__init__.py b/tests/mock_and_patch/__init__.py similarity index 100% rename from tests/e2e_integration_tests/__init__.py rename to tests/mock_and_patch/__init__.py diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py new file mode 100644 index 00000000..cfcfb8f0 --- /dev/null +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -0,0 +1,24 @@ +import tempfile +from datetime import datetime +from pathlib import Path + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def get_temp_session_path(session_timestamp: datetime) -> Path: + """ + Get a temp directory session path the test session will output to. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. + """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = ( + Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + ) + session_path.mkdir(exist_ok=True, parents=True) + _LOGGER.debug(f"Created temp session directory: {session_path}") + return session_path diff --git a/tests/test_active_node.py b/tests/test_active_node.py index addc595c..b6833182 100644 --- a/tests/test_active_node.py +++ b/tests/test_active_node.py @@ -60,7 +60,9 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state): 1, ) - active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED) + active_node.set_software_state_if_not_compromised( + SoftwareState.OVERWHELMED + ) assert active_node.software_state == expected_state @@ -98,7 +100,9 @@ def test_file_system_change(operating_state, expected_state): (HardwareState.ON, FileSystemState.CORRUPT), ], ) -def test_file_system_change_if_not_compromised(operating_state, expected_state): +def test_file_system_change_if_not_compromised( + operating_state, expected_state +): """ Test that a node cannot change its file system state. @@ -116,6 +120,8 @@ def test_file_system_change_if_not_compromised(operating_state, expected_state): 1, ) - active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT) + active_node.set_file_system_state_if_not_compromised( + FileSystemState.CORRUPT + ) assert active_node.file_system_state_actual == expected_state diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index efca7b0b..21e4857f 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -7,79 +7,78 @@ from primaite.environment.observations import ( NodeStatuses, ObservationsHandler, ) -from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config -@pytest.fixture -def env(request): - """Build Primaite environment for integration tests of observation space.""" - marker = request.node.get_closest_marker("env_config_paths") - training_config_path = marker.args[0]["training_config_path"] - lay_down_config_path = marker.args[0]["lay_down_config_path"] - env = _get_primaite_env_from_config( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path, - ) - yield env - - -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_without_obs.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) -def test_default_obs_space(env: Primaite): +def test_default_obs_space(temp_primaite_session): """Create environment with no obs space defined in config and check that the default obs space was created.""" - env.update_environent_obs() + with temp_primaite_session as session: + session.env.update_environent_obs() - components = env.obs_handler.registered_obs_components + components = session.env.obs_handler.registered_obs_components - assert len(components) == 1 - assert isinstance(components[0], NodeLinkTable) + assert len(components) == 1 + assert isinstance(components[0], NodeLinkTable) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_without_obs.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) -def test_registering_components(env: Primaite): +def test_registering_components(temp_primaite_session): """Test regitering and deregistering a component.""" - handler = ObservationsHandler() - component = NodeStatuses(env) - handler.register(component) - assert component in handler.registered_obs_components - handler.deregister(component) - assert component not in handler.registered_obs_components + with temp_primaite_session as session: + env = session.env + handler = ObservationsHandler() + component = NodeStatuses(env) + handler.register(component) + assert component in handler.registered_obs_components + handler.deregister(component) + assert component not in handler.registered_obs_components -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_NODE_LINK_TABLE.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestNodeLinkTable: """Test the NodeLinkTable observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with box observation space.""" - env.update_environent_obs() + with temp_primaite_session as session: + env = session.env + env.update_environent_obs() - # we have three nodes and two links, with two service - # therefore the box observation space will have: - # * 5 rows (3 nodes + 2 links) - # * 6 columns (four fixed and two for the services) - assert env.env_obs.shape == (5, 6) + # we have three nodes and two links, with two service + # therefore the box observation space will have: + # * 5 rows (3 nodes + 2 links) + # * 6 columns (four fixed and two for the services) + assert env.env_obs.shape == (5, 6) - def test_value(self, env: Primaite): + def test_value(self, temp_primaite_session): """Test that the observation is generated correctly. The laydown has: @@ -125,36 +124,45 @@ class TestNodeLinkTable: * 999 (999 traffic service1) * 0 (no traffic for service2) """ - # act = np.asarray([0,]) - obs, reward, done, info = env.step(0) # apply the 'do nothing' action + with temp_primaite_session as session: + env = session.env + # act = np.asarray([0,]) + obs, reward, done, info = env.step( + 0 + ) # apply the 'do nothing' action - assert np.array_equal( - obs, - [ - [1, 1, 3, 1, 1, 1], - [2, 1, 1, 1, 1, 4], - [3, 1, 1, 1, 0, 0], - [4, 0, 0, 0, 999, 0], - [5, 0, 0, 0, 999, 0], - ], - ) + assert np.array_equal( + obs, + [ + [1, 1, 3, 1, 1, 1], + [2, 1, 1, 1, 1, 4], + [3, 1, 1, 1, 0, 0], + [4, 0, 0, 0, 999, 0], + [5, 0, 0, 0, 999, 0], + ], + ) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_NODE_STATUSES.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestNodeStatuses: """Test the NodeStatuses observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with NodeStatuses as the only component.""" - assert env.env_obs.shape == (15,) + with temp_primaite_session as session: + env = session.env + assert env.env_obs.shape == (15,) - def test_values(self, env: Primaite): + def test_values(self, temp_primaite_session): """Test that the hardware and software states are encoded correctly. The laydown has: @@ -181,28 +189,38 @@ class TestNodeStatuses: * service 1 = n/a (0) * service 2 = n/a (0) """ - obs, _, _, _ = env.step(0) # apply the 'do nothing' action - assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) + with temp_primaite_session as session: + env = session.env + obs, _, _, _ = env.step(0) # apply the 'do nothing' action + assert np.array_equal( + obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0] + ) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT + / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestLinkTrafficLevels: """Test the LinkTrafficLevels observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with MultiDiscrete observation space.""" - env.update_environent_obs() + with temp_primaite_session as session: + env = session.env + env.update_environent_obs() - # we have two links and two services, so the shape should be 2 * 2 - assert env.env_obs.shape == (2 * 2,) + # we have two links and two services, so the shape should be 2 * 2 + assert env.env_obs.shape == (2 * 2,) - def test_values(self, env: Primaite): + def test_values(self, temp_primaite_session): """Test that traffic values are encoded correctly. The laydown has: @@ -212,12 +230,14 @@ class TestLinkTrafficLevels: * an IER trying to send 999 bits of data over both links the whole time (via the first service) * link bandwidth of 1000, therefore the utilisation is 99.9% """ - obs, reward, done, info = env.step(0) - obs, reward, done, info = env.step(0) + with temp_primaite_session as session: + env = session.env + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) - # the observation space has combine_service_traffic set to False, so the space has this format: - # [link1_service1, link1_service2, link2_service1, link2_service2] - # we send 999 bits of data via link1 and link2 on service 1. - # therefore the first and third elements should be 6 and all others 0 - # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) - assert np.array_equal(obs, [6, 0, 6, 0]) + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 999 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) + assert np.array_equal(obs, [6, 0, 6, 0]) diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py new file mode 100644 index 00000000..8c8d2b80 --- /dev/null +++ b/tests/test_primaite_session.py @@ -0,0 +1,61 @@ +import os + +import pytest + +from primaite import getLogger +from primaite.config.lay_down_config import dos_very_basic_config_path +from primaite.config.training_config import main_training_config_path + +_LOGGER = getLogger(__name__) + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[main_training_config_path(), dos_very_basic_config_path()]], + indirect=True, +) +def test_primaite_session(temp_primaite_session): + """Tests the PrimaiteSession class and its outputs.""" + with temp_primaite_session as session: + session_path = session.session_path + assert session_path.exists() + session.learn() + # Learning outputs are saved in session.learning_path + session.evaluate() + # Evaluation outputs are saved in session.evaluation_path + + # If you need to inspect any session outputs, it must be done inside + # the context manager + + # Check that the metadata json file exists + assert (session_path / "session_metadata.json").exists() + + # Check that the network png file exists + assert (session_path / f"network_{session.timestamp_str}.png").exists() + + # Check that both the transactions and av reward csv files exist + for file in session.learning_path.iterdir(): + if file.suffix == ".csv": + assert ( + "all_transactions" in file.name + or "average_reward_per_episode" in file.name + ) + + # Check that both the transactions and av reward csv files exist + for file in session.evaluation_path.iterdir(): + if file.suffix == ".csv": + assert ( + "all_transactions" in file.name + or "average_reward_per_episode" in file.name + ) + + _LOGGER.debug("Inspecting files in temp session path...") + for dir_path, dir_names, file_names in os.walk(session_path): + for file in file_names: + path = os.path.join(dir_path, file) + file_str = path.split(str(session_path))[-1] + _LOGGER.debug(f" {file_str}") + + # Now that we've exited the context manager, the session.session_path + # directory and its contents are deleted + assert not session_path.exists() diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index abe8115c..e7312777 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -18,7 +18,9 @@ from primaite.nodes.service_node import ServiceNode "starting_operating_state, expected_operating_state", [(HardwareState.RESETTING, HardwareState.ON)], ) -def test_node_resets_correctly(starting_operating_state, expected_operating_state): +def test_node_resets_correctly( + starting_operating_state, expected_operating_state +): """Tests that a node resets correctly.""" active_node = ActiveNode( node_id="0", diff --git a/tests/test_reward.py b/tests/test_reward.py index c3fcdfc4..95603b54 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -1,26 +1,33 @@ +import pytest + from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config -def test_rewards_are_being_penalised_at_each_step_function(): +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_rewards_are_being_penalised_at_each_step_function( + temp_primaite_session, +): """ Test that hardware state is penalised at each step. When the initial state is OFF compared to reference state which is ON. - """ - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_lay_down_config.yaml", - ) - """ - On different steps (of the 13 in total) these are the following rewards for config_6 which are activated: + On different steps (of the 13 in total) these are the following rewards + for config_6 which are activated: File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3) Hardware State: onShouldBeOff = -2 (between Steps 4 & 6) Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9) - Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12) + Software State (Software State): goodShouldBeCompromised = 5 (between + Steps 10 & 12) Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26 Step Count: 13 @@ -28,5 +35,8 @@ def test_rewards_are_being_penalised_at_each_step_function(): For the 4 steps where this occurs the average reward is: Average Reward: 2 (26 / 13) """ - print("average reward", env.average_reward) - assert env.average_reward == -8.0 + with temp_primaite_session as session: + session.evaluate() + session.close() + ev_rewards = session.eval_av_reward_per_episode_csv() + assert ev_rewards[1] == -8.0 diff --git a/tests/test_service_node.py b/tests/test_service_node.py index 4383fc1b..9e760b23 100644 --- a/tests/test_service_node.py +++ b/tests/test_service_node.py @@ -45,7 +45,9 @@ def test_service_state_change(operating_state, expected_state): (HardwareState.ON, SoftwareState.OVERWHELMED), ], ) -def test_service_state_change_if_not_comprised(operating_state, expected_state): +def test_service_state_change_if_not_comprised( + operating_state, expected_state +): """ Test that a node cannot change the state of a running service. @@ -65,6 +67,8 @@ def test_service_state_change_if_not_comprised(operating_state, expected_state): service = Service("TCP", 80, SoftwareState.GOOD) service_node.add_service(service) - service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED) + service_node.set_service_state_if_not_compromised( + "TCP", SoftwareState.OVERWHELMED + ) assert service_node.get_service_state("TCP") == expected_state diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 8ff43fe6..1cf63cde 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,9 +1,10 @@ import time +import pytest + from primaite.common.enums import HardwareState from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config def run_generic_set_actions(env: Primaite): @@ -44,59 +45,72 @@ def run_generic_set_actions(env: Primaite): # env.close() -def test_single_action_space_is_valid(): - """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "single_action_space_lay_down_config.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_single_action_space_is_valid(temp_primaite_session): + """Test single action space is valid.""" + with temp_primaite_session as session: + env = session.env - run_generic_set_actions(env) - - # Retrieve the action space dictionary values from environment - env_action_space_dict = env.action_dict.values() - # Flags to check the conditions of the action space - contains_acl_actions = False - contains_node_actions = False - both_action_spaces = False - # Loop through each element of the list (which is every value from the dictionary) - for dict_item in env_action_space_dict: - # Node action detected - if len(dict_item) == 4: - contains_node_actions = True - # Link action detected - elif len(dict_item) == 6: - contains_acl_actions = True - # If both are there then the ANY action type is working - if contains_node_actions and contains_acl_actions: - both_action_spaces = True - # Check condition should be True - assert both_action_spaces + run_generic_set_actions(env) + # Retrieve the action space dictionary values from environment + env_action_space_dict = env.action_dict.values() + # Flags to check the conditions of the action space + contains_acl_actions = False + contains_node_actions = False + both_action_spaces = False + # Loop through each element of the list (which is every value from the dictionary) + for dict_item in env_action_space_dict: + # Node action detected + if len(dict_item) == 4: + contains_node_actions = True + # Link action detected + elif len(dict_item) == 6: + contains_acl_actions = True + # If both are there then the ANY action type is working + if contains_node_actions and contains_acl_actions: + both_action_spaces = True + # Check condition should be True + assert both_action_spaces -def test_agent_is_executing_actions_from_both_spaces(): +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "single_action_space_fixed_blue_actions_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "single_action_space_lay_down_config.yaml", - ) - # Run environment with specified fixed blue agent actions only - run_generic_set_actions(env) - # Retrieve hardware state of computer_1 node in laydown config - # Agent turned this off in Step 5 - computer_node_hardware_state = env.nodes["1"].hardware_state - # Retrieve the Access Control List object stored by the environment at the end of the episode - access_control_list = env.acl - # Use the Access Control List object acl object attribute to get dictionary - # Use dictionary.values() to get total list of all items in the dictionary - acl_rules_list = access_control_list.acl.values() - # Length of this list tells you how many items are in the dictionary - # This number is the frequency of Access Control Rules in the environment - # In the scenario, we specified that the agent should create only 1 acl rule - num_of_rules = len(acl_rules_list) - # Therefore these statements below MUST be true - assert computer_node_hardware_state == HardwareState.OFF - assert num_of_rules == 1 + with temp_primaite_session as session: + env = session.env + # Run environment with specified fixed blue agent actions only + run_generic_set_actions(env) + # Retrieve hardware state of computer_1 node in laydown config + # Agent turned this off in Step 5 + computer_node_hardware_state = env.nodes["1"].hardware_state + # Retrieve the Access Control List object stored by the environment at the end of the episode + access_control_list = env.acl + # Use the Access Control List object acl object attribute to get dictionary + # Use dictionary.values() to get total list of all items in the dictionary + acl_rules_list = access_control_list.acl.values() + # Length of this list tells you how many items are in the dictionary + # This number is the frequency of Access Control Rules in the environment + # In the scenario, we specified that the agent should create only 1 acl rule + num_of_rules = len(acl_rules_list) + # Therefore these statements below MUST be true + assert computer_node_hardware_state == HardwareState.OFF + assert num_of_rules == 1 diff --git a/tests/test_training_config.py b/tests/test_training_config.py index 02e90d30..88bc802b 100644 --- a/tests/test_training_config.py +++ b/tests/test_training_config.py @@ -16,7 +16,9 @@ def test_legacy_lay_down_config_yaml_conversion(): with open(new_path, "r") as file: new_dict = yaml.safe_load(file) - converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict) + converted_dict = training_config.convert_legacy_training_config_dict( + legacy_dict + ) for key, value in new_dict.items(): assert converted_dict[key] == value