#917 - Integrated the PrimaiteSession into all tests.

- Ran the full set of pre-commit hooks, which surfaced a large number of required fixes
This commit is contained in:
Chris McCarthy
2023-06-30 09:08:13 +01:00
parent 7f912df383
commit 73015802ec
62 changed files with 1880 additions and 802 deletions

View File

@@ -59,4 +59,4 @@ steps:
- script: |
pytest tests/
displayName: 'Run unmarked tests'
displayName: 'Run tests'

View File

@@ -13,6 +13,9 @@ repos:
rev: 23.1.0
hooks:
- id: black
args: [ "--line-length=79" ]
additional_dependencies:
- jupyter
- repo: http://github.com/pycqa/isort
rev: 5.12.0
hooks:
@@ -22,4 +25,5 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies: [ flake8-docstrings ]
additional_dependencies:
- flake8-docstrings

View File

@@ -22,7 +22,7 @@ The environment config file consists of the following attributes:
* **agent_framework** [enum]
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
* NONE - Where a user-developed agent is to be used
@@ -30,14 +30,14 @@ The environment config file consists of the following attributes:
* RLLIB - Ray RLlib.
* **agent_identifier**
This identifies the agent to use for the session. Select from one of the following:
* A2C - Advantage Actor Critic
* PPO - Proximal Policy Optimization
* HARDCODED - A custom-built deterministic agent
* RANDOM - A stochastic random agent
* **action_type** [enum]

View File

@@ -47,6 +47,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| astunparse | 1.6.3 | BSD License | https://github.com/simonpercivall/astunparse |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall |
@@ -103,6 +105,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn |
@@ -111,9 +115,13 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| furo | 2023.3.27 | MIT License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-auth-oauthlib | 1.0.0 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| grpcio | 1.51.3 | Apache Software License | https://grpc.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -121,6 +129,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| h5py | 3.9.0 | BSD License | https://www.h5py.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| idna | 3.4 | BSD License | https://github.com/kjd/idna |
@@ -141,6 +151,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 |
@@ -151,14 +163,14 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-events | 0.6.3 | BSD License | http://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_client | 8.2.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_core | 5.3.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org |
@@ -171,10 +183,14 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| keras | 2.12.0 | Apache Software License | https://keras.io/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| kiwisolver | 1.4.4 | BSD License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py |
@@ -183,19 +199,23 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| mpmath | 1.3.0 | BSD License | http://mpmath.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic |
| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| nbclient | 0.8.0 | BSD License | https://jupyter.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -217,6 +237,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging |
@@ -231,11 +253,17 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| plotly | 5.15.0 | MIT License | https://plotly.com/python/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 1.2.1 | GFX | UNKNOWN |
| primaite | 2.0.0rc1 | GFX | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 2.0.0rc1 | GFX | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| primaite | 2.0.0rc1 | GFX | UNKNOWN |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
@@ -295,6 +323,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| rsa | 4.9 | Apache Software License | https://stuvel.eu/rsa |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| scikit-image | 0.20.0 | BSD License | https://scikit-image.org |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| scipy | 1.10.1 | BSD License | https://scipy.org/ |
@@ -335,14 +365,26 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard | 2.12.3 | Apache Software License | https://github.com/tensorflow/tensorboard |
| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-data-server | 0.7.0 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com |
@@ -377,6 +419,8 @@
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket |

View File

@@ -31,6 +31,8 @@ dependencies = [
"networkx==3.1",
"numpy==1.23.5",
"platformdirs==3.5.1",
"plotly==5.15.0",
"polars==0.18.4",
"PyYAML==6.0",
"ray[rllib]==2.2.0",
"stable-baselines3==1.6.2",
@@ -69,3 +71,12 @@ tensorflow = [
[project.scripts]
primaite = "primaite.cli:app"
[tool.isort]
profile = "black"
line_length = 79
force_sort_within_sections = "False"
order_by_type = "False"
[tool.black]
line-length = 79

View File

@@ -1 +1 @@
2.0.0rc1
2.0.0rc1

View File

@@ -3,12 +3,10 @@ import logging
import logging.config
import sys
from bisect import bisect
from logging import Formatter, LogRecord, StreamHandler
from logging import Logger
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict
from typing import Final
from typing import Dict, Final
import pkg_resources
import yaml
@@ -21,7 +19,6 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
def _get_primaite_config():
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
if not config_path.exists():
config_path = Path(
pkg_resources.resource_filename(
"primaite", "setup/_package_data/primaite_config.yaml"
@@ -37,7 +34,9 @@ def _get_primaite_config():
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
primaite_config["log_level"] = log_level_map[
primaite_config["logging"]["log_level"]
]
return primaite_config
@@ -111,9 +110,13 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][
"WARNING"
],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"]
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][
"CRITICAL"
],
}
)

View File

@@ -10,7 +10,9 @@ class AccessControlList:
def __init__(self):
"""Init."""
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
self.acl: Dict[
str, AccessControlList
] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
@@ -37,13 +39,17 @@ class AccessControlList:
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == "ANY"
)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
or (
_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY"
)
):
return True
else:
return False
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
def is_blocked(
self, _source_ip_address, _dest_ip_address, _protocol, _port
):
"""
Checks for rules that block a protocol / port.
@@ -87,7 +93,9 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
new_rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
@@ -102,7 +110,9 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
hash_value = hash(rule)
# There will not always be something 'popable' since the agent will be trying random things
try:
@@ -114,7 +124,9 @@ class AccessControlList:
"""Removes all rules."""
self.acl.clear()
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def get_dictionary_hash(
self, _permission, _source_ip, _dest_ip, _protocol, _port
):
"""
Produces a hash value for a rule.
@@ -128,6 +140,8 @@ class AccessControlList:
Returns:
Hash value based on rule parameters.
"""
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
hash_value = hash(rule)
return hash_value

View File

@@ -30,7 +30,13 @@ class ACLRule:
Returns hash of core parameters.
"""
return hash(
(self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
(
self.permission,
self.source_ip,
self.dest_ip,
self.protocol,
self.port,
)
)
def get_permission(self):

View File

@@ -1,27 +1,31 @@
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Optional, Final, Dict, Union
from typing import Dict, Final, Optional, Union
from uuid import uuid4
import yaml
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config
from primaite.config import training_config
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _get_session_path(session_timestamp: datetime) -> Path:
def get_session_path(session_timestamp: datetime) -> Path:
"""
Get a temp directory session path the test session will output to.
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
@@ -35,13 +39,15 @@ def _get_session_path(session_timestamp: datetime) -> Path:
class AgentSessionABC(ABC):
"""
An ABC that manages training and/or evaluation of agents in PrimAITE.
This class cannot be directly instantiated and must be inherited from
with all abstract methods implemented.
"""
@abstractmethod
def __init__(
self,
training_config_path,
lay_down_config_path
):
def __init__(self, training_config_path, lay_down_config_path):
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
@@ -66,9 +72,8 @@ class AgentSessionABC(ABC):
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = _get_session_path(self.session_timestamp)
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
@property
def timestamp_str(self) -> str:
@@ -78,17 +83,23 @@ class AgentSessionABC(ABC):
@property
def learning_path(self) -> Path:
"""The learning outputs path."""
return self.session_path / "learning"
path = self.session_path / "learning"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def evaluation_path(self) -> Path:
"""The evaluation outputs path."""
return self.session_path / "evaluation"
path = self.session_path / "evaluation"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def checkpoints_path(self) -> Path:
"""The Session checkpoints path."""
return self.learning_path / "checkpoints"
path = self.learning_path / "checkpoints"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def uuid(self):
@@ -118,14 +129,8 @@ class AgentSessionABC(ABC):
"uuid": self.uuid,
"start_datetime": self.session_timestamp.isoformat(),
"end_datetime": None,
"learning": {
"total_episodes": None,
"total_time_steps": None
},
"evaluation": {
"total_episodes": None,
"total_time_steps": None
},
"learning": {"total_episodes": None, "total_time_steps": None},
"evaluation": {"total_episodes": None, "total_time_steps": None},
"env": {
"training_config": self._training_config.to_dict(
json_serializable=True
@@ -156,11 +161,19 @@ class AgentSessionABC(ABC):
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
metadata_dict["learning"][
"total_episodes"
] = self._env.episode_count # noqa
metadata_dict["learning"][
"total_time_steps"
] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
metadata_dict["evaluation"][
"total_episodes"
] = self._env.episode_count # noqa
metadata_dict["evaluation"][
"total_time_steps"
] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -187,26 +200,47 @@ class AgentSessionABC(ABC):
@abstractmethod
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_learn:
_LOGGER.info("Finished learning")
_LOGGER.debug("Writing transactions")
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
self._plot_av_reward_per_episode(learning_session=True)
@abstractmethod
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
@abstractmethod
@@ -216,6 +250,7 @@ class AgentSessionABC(ABC):
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
"""Load an agent from file."""
if not isinstance(path, Path):
path = Path(path)
@@ -246,21 +281,56 @@ class AgentSessionABC(ABC):
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
msg = (
f"Failed to load PrimAITE Session, path does not exist: {path}"
)
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@abstractmethod
def save(self):
    """Save the underlying agent, passing it the session path."""
    save_agent = self._agent.save
    save_agent(self.session_path)
@abstractmethod
def export(self):
    """
    Export the agent to a transportable file format.

    The base implementation does nothing and returns ``None``.
    """
def close(self):
    """
    Close the output writers held by the environment.

    The episode average reward writer is closed first, followed by the
    transaction writer.
    """
    for writer_name in ("episode_av_reward_writer", "transaction_writer"):
        getattr(self._env, writer_name).close()  # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True):
    """
    Plot the average reward per episode and save the plot as a PNG image.

    The CSV path handed to ``plot_av_reward_per_episode`` and the PNG path
    the figure is written to share the same file stem, which is built from
    the session timestamp. Both live in the learning or the evaluation
    output directory.

    :param learning_session: If ``True``, use the learning output
        directory and title the plot "(Learning)"; otherwise use the
        evaluation output directory and title the plot "(Evaluation)".
    """
    if learning_session:
        session_type = "(Learning)"
        output_dir = self.learning_path
    else:
        session_type = "(Evaluation)"
        output_dir = self.evaluation_path

    title = f"PrimAITE Session {self.timestamp_str} {session_type}"
    subtitle = str(self._training_config)
    file_stem = f"average_reward_per_episode_{self.timestamp_str}"
    csv_path = output_dir / f"{file_stem}.csv"
    image_path = output_dir / f"{file_stem}.png"

    fig = plot_av_reward_per_episode(csv_path, title, subtitle)
    fig.write_image(image_path)
    # Report where the image was written, not the CSV it was built from.
    _LOGGER.debug(
        f"Saved average rewards per episode plot to: {image_path}"
    )
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluating deterministic agents.
This class cannot be directly instantiated and must be inherited from
with all abstract methods implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
self._setup()
@@ -270,13 +340,12 @@ class HardCodedAgentSessionABC(AgentSessionABC):
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
@@ -284,11 +353,20 @@ class HardCodedAgentSessionABC(AgentSessionABC):
pass
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
@@ -296,20 +374,31 @@ class HardCodedAgentSessionABC(AgentSessionABC):
pass
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
obs = self._env.reset()
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
@@ -322,15 +411,18 @@ class HardCodedAgentSessionABC(AgentSessionABC):
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
super().evaluate()
@classmethod
def load(cls, path: Optional[Union[str, Path]] = None) -> None:
    """
    Log a warning that deterministic agents cannot be loaded from file.

    :param path: Ignored. Accepted only so that the signature stays
        compatible with ``AgentSessionABC.load(path)``; without it a call
        that passes a path raises ``TypeError`` instead of warning.
    """
    _LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
    """Log a warning that deterministic agents cannot be saved."""
    _LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
    """Log a warning that deterministic agents cannot be exported."""
    _LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -11,9 +11,13 @@ from primaite.common.enums import HardCodedAgentView
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
if (
self._training_config.hard_coded_agent_view
== HardCodedAgentView.BASIC
):
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
@@ -22,6 +26,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes):
"""
Get blocked green IERs.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
blocked_green_iers = {}
for green_ier_id, green_ier in green_iers.items():
@@ -33,8 +43,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(source_node_address, dest_node_address, protocol,
port):
if acl.is_blocked(
source_node_address, dest_node_address, protocol, port
):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
@@ -42,8 +53,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get matching ACL rules for an IER.
"""
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = ier.get_dest_node_id()
@@ -51,17 +64,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(source_node_address,
dest_node_address, protocol,
port)
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get blocking ACL rules for an IER.
Warning: Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked)
"""
.. warning::
Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked).
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
blocked_rules = {}
@@ -74,8 +92,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get all allowing ACL rules for an IER.
"""
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
allowed_rules = {}
@@ -85,9 +105,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return allowed_rules
def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol,
port, acl,
nodes, services_list):
def get_matching_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get matching ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if source_node_id != "ANY":
source_node_address = nodes[str(source_node_id)].ip_address
else:
@@ -100,21 +133,39 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if protocol != "ANY":
protocol = services_list[
protocol - 1] # -1 as dont have to account for ANY in list of services
protocol - 1
] # -1 as dont have to account for ANY in list of services
matching_rules = acl.get_relevant_rules(source_node_address,
dest_node_address, protocol,
port)
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
return matching_rules
def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol,
port, acl,
nodes, services_list):
matching_rules = self.get_matching_acl_rules(source_node_id,
dest_node_id,
protocol, port, acl,
nodes,
services_list)
def get_allow_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the ALLOW ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
@@ -123,14 +174,31 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return allowed_rules
def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port,
acl,
nodes, services_list):
matching_rules = self.get_matching_acl_rules(source_node_id,
dest_node_id,
protocol, port, acl,
nodes,
services_list)
def get_deny_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the DENY ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
@@ -141,7 +209,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def _calculate_action_full_view(self, obs):
"""
Given an observation and the environment calculate a good acl-based action for the blue agent to take
Calculate a good acl-based action for the blue agent to take.
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
@@ -167,8 +235,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
an overwhelmed state, so we don't do this.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
#obs = convert_to_old_obs(obs)
# obs = convert_to_old_obs(obs)
r_obs = transform_change_obs_readable(obs)
_, _, _, *s = r_obs
@@ -184,7 +254,6 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
for service_num, service_states in enumerate(s):
for x, service_state in enumerate(service_states):
if service_state == "COMPROMISED":
action_source_id = x + 1 # +1 as 0 is any
action_destination_id = "ANY"
action_protocol = service_num + 1 # +1 as 0 is any
@@ -215,19 +284,23 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(
get_node_of_ip(action_source_ip, self._env.nodes))
get_node_of_ip(action_source_ip, self._env.nodes)
)
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(
get_node_of_ip(action_destination_ip,
self._env.nodes))
get_node_of_ip(
action_destination_ip, self._env.nodes
)
)
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(
action_protocol_name) + 1
self._env.services_list.index(action_protocol_name)
+ 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = self._env.ports_list.index(
action_port_name) + 1 # convert port name e.g. '80' to index
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
@@ -258,21 +331,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(
self._env.green_iers, self._env.acl,
self._env.nodes)
self._env.green_iers, self._env.acl, self._env.nodes
)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(ier,
self._env.acl,
self._env.nodes)
allowing_rules = self.get_allow_acl_rules_for_ier(
ier, self._env.acl, self._env.nodes
)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(
service_name_to_check)
service_name_to_check
)
# Service state of the source node in the ier
service_state = s[service_id_to_check][node_id_to_check - 1]
@@ -283,11 +356,13 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action_source_id = int(ier.get_source_node_id())
action_destination_id = int(ier.get_dest_node_id())
action_protocol_name = ier.get_protocol()
action_protocol = self._env.services_list.index(
action_protocol_name) + 1 # convert name e.g. 'TCP' to index
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = ier.get_port()
action_port = self._env.ports_list.index(
action_port_name) + 1 # convert port name e.g. '80' to index
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
@@ -311,19 +386,25 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return action
def _calculate_action_basic_view(self, obs):
"""Given an observation calculate a good acl-based action for the blue agent to take
"""Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge of previous actions taken and
NO reward feedback.
Uses ONLY information from the current observation with NO knowledge
of previous actions taken and NO reward feedback.
We rely on randomness to select the precise action, as we want to block all traffic originating from
a compromised node, without being able to tell:
We rely on randomness to select the precise action, as we want to
block all traffic originating from a compromised node, without being
able to tell:
1. Which ACL rules already exist
1. Which actions the agent has already tried.
2. Which actions the agent has already tried.
There is a high probability that the correct rule will not be deleted before the state becomes overwhelmed.
There is a high probability that the correct rule will not be deleted
before the state becomes overwhelmed.
Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted.
Currently, a deny rule does not overwrite an allow rule. The allow
rules must be deleted.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
@@ -333,27 +414,35 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
s = [*s]
number_of_nodes = len(
[i for i in o if i != "NONE"]) # number of nodes (not links)
[i for i in o if i != "NONE"]
) # number of nodes (not links)
for service_num, service_states in enumerate(s):
comprimised_states = [n for n, i in enumerate(service_states) if
i == "COMPROMISED"]
comprimised_states = [
n for n, i in enumerate(service_states) if i == "COMPROMISED"
]
if len(comprimised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = np.random.choice(
comprimised_states) + 1 # +1 as 0 would be any
compromised_node = (
np.random.choice(comprimised_states) + 1
) # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(
list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = int(
action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
list(range(1, number_of_nodes + 1)) + ["ANY"]
)
action_destination_ip = (
int(action_destination_ip)
if action_destination_ip != "ANY"
else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block
# Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port
# Bad assumption that number of protocols equals number of ports
# AND no rules exist with an ANY port
action_port = np.random.choice(list(range(1, len(s) + 1)))
action = [

View File

@@ -1,16 +1,21 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_change_obs_readable,
)
from primaite.agents.utils import (
transform_action_node_enum,
transform_change_obs_readable,
)
class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs):
"""Given an observation calculate a good node-based action for the blue agent to take"""
"""
Calculate a good node-based action for the blue agent to take.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, os, *s = r_obs
@@ -18,7 +23,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
if len(r_obs) == 4: # only 1 service
s = [*s]
# Check in order of most important states (order doesn't currently matter, but it probably should)
# Check in order of most important states (order doesn't currently
# matter, but it probably should)
# First see if any OS states are compromised
for x, os_state in enumerate(os):
if os_state == "COMPROMISED":
@@ -26,8 +32,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
action_node_property = "OS"
property_action = "PATCHING"
action_service_index = 0 # does nothing isn't relevant for os
action = [action_node_id, action_node_property,
property_action, action_service_index]
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
@@ -44,8 +54,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
property_action = "PATCHING"
action_service_index = service_num
action = [action_node_id, action_node_property,
property_action, action_service_index]
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
@@ -63,8 +77,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
property_action = "PATCHING"
action_service_index = service_num
action = [action_node_id, action_node_property,
property_action, action_service_index]
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
@@ -75,10 +93,18 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
if os_state == "OFF":
action_node_id = x + 1
action_node_property = "OPERATING"
property_action = "ON" # Why reset it when we can just turn it on
action_service_index = 0 # does nothing isn't relevant for operating state
action = [action_node_id, action_node_property,
property_action, action_service_index]
property_action = (
"ON" # Why reset it when we can just turn it on
)
action_service_index = (
0 # does nothing isn't relevant for operating state
)
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action, action_dict)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
@@ -89,8 +115,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
action_node_property = "NONE"
property_action = "ON"
action_service_index = 0
action = [action_node_id, action_node_property, property_action,
action_service_index]
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)

View File

@@ -1,28 +1,35 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import tensorflow as tf
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
import tensorflow as tf
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier, \
DeepLearningFramework
from primaite.common.enums import (
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
)
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _env_creator(env_config):
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"]
timestamp_str=env_config["timestamp_str"],
)
@@ -37,16 +44,15 @@ def _custom_log_creator(session_path: Path):
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(
self,
training_config_path,
lay_down_config_path
):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = (f"Expected RLLIB agent_framework, "
f"got {self._training_config.agent_framework}")
msg = (
f"Expected RLLIB agent_framework, "
f"got {self._training_config.agent_framework}"
)
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
@@ -54,8 +60,10 @@ class RLlibAgent(AgentSessionABC):
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = ("Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}")
msg = (
"Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}"
)
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: PPOConfig
@@ -86,8 +94,12 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
metadata_dict["total_episodes"] = self._current_result[
"episodes_total"
]
metadata_dict["total_time_steps"] = self._current_result[
"timesteps_total"
]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -106,43 +118,48 @@ class RLlibAgent(AgentSessionABC):
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str
)
timestamp_str=self.timestamp_str,
),
)
self._agent_config.training(
train_batch_size=self._training_config.num_steps
)
self._agent_config.framework(
framework="tf"
)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_steps
horizon=self._training_config.num_steps,
)
self._agent: Algorithm = self._agent_config.build(
logger_creator=_custom_log_creator(self.session_path)
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (
(episode_count % checkpoint_n == 0)
or (episode_count == self._training_config.num_episodes)
if (episode_count % checkpoint_n == 0) or (
episode_count == self._training_config.num_episodes
):
self._agent.save(self.checkpoints_path)
self._agent.save(str(self.checkpoints_path))
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
# Temporarily override train_batch_size and horizon
if time_steps:
self._agent_config.train_batch_size = time_steps
@@ -150,37 +167,53 @@ class RLlibAgent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps...")
_LOGGER.info(
f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps..."
)
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH:
if (
self._training_config.deep_learning_framework
!= DeepLearningFramework.TORCH
):
policy = self._agent.get_policy()
tf.compat.v1.summary.FileWriter(
self.session_path / "ray_results",
policy.get_session().graph
self.session_path / "ray_results", policy.get_session().graph
)
super().learn()
self._agent.stop()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError
def _get_latest_checkpoint(self):
raise NotImplementedError
@classmethod
def load(cls):
def load(cls, path: Union[str, Path]) -> RLlibAgent:
"""Load an agent from file."""
raise NotImplementedError
def save(self):
    """
    Save the agent (not implemented).

    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()
def export(self):
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,27 +1,30 @@
from typing import Optional
from __future__ import annotations
from pathlib import Path
from typing import Optional, Union
import numpy as np
from stable_baselines3 import PPO, A2C
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentIdentifier, AgentFramework
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class SB3Agent(AgentSessionABC):
def __init__(
self,
training_config_path,
lay_down_config_path
):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = (f"Expected SB3 agent_framework, "
f"got {self._training_config.agent_framework}")
msg = (
f"Expected SB3 agent_framework, "
f"got {self._training_config.agent_framework}"
)
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
@@ -29,8 +32,10 @@ class SB3Agent(AgentSessionABC):
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = ("Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}")
msg = (
"Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier}"
)
_LOGGER.error(msg)
raise ValueError(msg)
@@ -52,25 +57,26 @@ class SB3Agent(AgentSessionABC):
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path
tensorboard_log=self._tensorboard_log_path,
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (
(episode_count % checkpoint_n == 0)
or (episode_count == self._training_config.num_episodes)
if (episode_count % checkpoint_n == 0) or (
episode_count == self._training_config.num_episodes
):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
checkpoint_path = (
self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
)
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
@@ -78,33 +84,54 @@ class SB3Agent(AgentSessionABC):
pass
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps...")
_LOGGER.info(
f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps..."
)
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.close()
self.close()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
deterministic: bool = True,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if not time_steps:
time_steps = self._training_config.num_steps
@@ -116,27 +143,31 @@ class SB3Agent(AgentSessionABC):
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(f"Beginning {deterministic_str} evaluation for "
f"{episodes} episodes @ {time_steps} time steps...")
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for "
f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(
obs,
deterministic=deterministic
obs, deterministic=deterministic
)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
_LOGGER.info(f"Finished evaluation")
super().evaluate()
@classmethod
def load(self):
def load(cls, path: Union[str, Path]) -> SB3Agent:
"""Load an agent from file."""
raise NotImplementedError
def save(self):
"""Save the agent."""
raise NotImplementedError
def export(self):
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -4,9 +4,9 @@ from primaite.common.enums import (
HardwareState,
LinkStatus,
NodeHardwareAction,
NodePOLType,
NodeSoftwareAction,
SoftwareState,
NodePOLType
)
@@ -16,14 +16,17 @@ def transform_action_node_readable(action):
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
2
] <= 1:
elif (
action_node_property == "OS" or action_node_property == "SERVICE"
) and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -38,6 +41,9 @@ def transform_action_acl_readable(action):
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
@@ -62,6 +68,9 @@ def is_valid_node_action(action):
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_node_readable(action)
@@ -77,7 +86,10 @@ def is_valid_node_action(action):
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
@@ -92,6 +104,9 @@ def is_valid_acl_action(action):
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_acl_readable(action)
@@ -118,7 +133,12 @@ def is_valid_acl_action(action):
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action."""
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if is_valid_acl_action(action) is False:
return False
@@ -136,13 +156,15 @@ def is_valid_acl_action_extra(action):
return True
def transform_change_obs_readable(obs):
"""Transform list of transactions to readable list of each observation property
"""
Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
@@ -151,7 +173,9 @@ def transform_change_obs_readable(obs):
for service in range(3, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
service_states = [
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
]
new_obs.append(service_states)
return new_obs
@@ -159,10 +183,13 @@ def transform_change_obs_readable(obs):
def transform_obs_readable(obs):
"""
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
"""
Transform observation to readable format.
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
# Convert list of tuples to list of lists
@@ -172,7 +199,12 @@ def transform_obs_readable(obs):
def convert_to_new_obs(obs, num_nodes=10):
"""Convert original gym Box observation space to new multiDiscrete observation space"""
"""
Convert original gym Box observation space to new multiDiscrete observation space.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
@@ -180,7 +212,9 @@ def convert_to_new_obs(obs, num_nodes=10):
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
"""
Convert to old observation, links filled with 0's as no information is included in new observation space
Convert to old observation.
Links filled with 0's as no information is included in new observation space.
example:
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
@@ -190,13 +224,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
# Add empty links back and add node ID back
s = np.zeros([reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], dtype=np.int64)
s = np.zeros(
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
dtype=np.int64,
)
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
s[:num_nodes, 1:] = reshaped_nodes # put values back in
new_obs = s
@@ -209,14 +247,19 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
return new_obs
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""Return string describing change between two observations
def describe_obs_change(
obs1, obs2, num_nodes=10, num_links=10, num_services=1
):
"""
Return string describing change between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 2 set to GOOD'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
@@ -236,20 +279,27 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
def _describe_obs_change_helper(obs_change, is_link):
""" "
Helper funcion to describe what has changed
"""
Helper function to describe what has changed.
example:
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
# Indexes where a change has occurred, not including 0th index
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
index_changed = [
i for i in range(1, len(obs_change)) if obs_change[i] != -1
]
# Node pol types, Indexes >= 3 are service nodes
NodePOLTypes = [
NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed
NodePOLType(i).name
if i < 3
else NodePOLType(3).name + " " + str(i - 3)
for i in index_changed
]
# Account for hardware states, software states and links
states = [
@@ -263,8 +313,8 @@ def _describe_obs_change_helper(obs_change, is_link):
if not is_link:
desc = f"ID {obs_change[0]}:"
for NodePOLType, state in list(zip(NodePOLTypes, states)):
desc = desc + " " + NodePOLType + " changed to " + state + "."
for node_pol_type, state in list(zip(NodePOLTypes, states)):
desc = desc + " " + node_pol_type + " changed to " + state + "."
else:
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
@@ -273,12 +323,14 @@ def _describe_obs_change_helper(obs_change, is_link):
def transform_action_node_enum(action):
"""
Convert a node action from readable string format, to enumerated format
Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
"""
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_id = action[0]
action_node_property = NodePOLType[action[1]].value
@@ -291,24 +343,33 @@ def transform_action_node_enum(action):
action_service_index = action[3]
new_action = [action_node_id, action_node_property, property_action, action_service_index]
new_action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
return new_action
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
"""
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
elif (
action_node_property == "OS" or action_node_property == "SERVICE"
) and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -319,9 +380,11 @@ def transform_action_node_readable(action):
def node_action_description(action):
"""
Generate string describing a node-based action
"""
Generate string describing a node-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if isinstance(action[1], (int, np.int64)):
# transform action to readable format
action = transform_action_node_readable(action)
@@ -334,7 +397,9 @@ def node_action_description(action):
if property_action == "NONE":
return ""
if node_property == "OPERATING" or node_property == "OS":
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
description = (
f"NODE {node_id}, {node_property}, SET TO {property_action}"
)
elif node_property == "SERVICE":
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
else:
@@ -343,34 +408,13 @@ def node_action_description(action):
return description
def transform_action_acl_readable(action):
"""
Transform an ACL action to a more readable format
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == 0:
new_action[n + 2] = "ANY"
return new_action
def transform_action_acl_enum(action):
"""
Convert a acl action from readable string format, to enumerated format
"""
Convert acl action from readable str format, to enumerated format.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
action_permissions = {"DENY": 0, "ALLOW": 1}
@@ -388,8 +432,12 @@ def transform_action_acl_enum(action):
def acl_action_description(action):
"""generate string describing a acl-based action"""
"""
Generate string describing an acl-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if isinstance(action[0], (int, np.int64)):
# transform action to readable format
action = transform_action_acl_readable(action)
@@ -406,11 +454,13 @@ def acl_action_description(action):
def get_node_of_ip(ip, node_dict):
"""
Get the node ID of an IP address
Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be obtained from env.nodes)
"""
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
for node_key, node_value in node_dict.items():
node_ip = node_value.ip_address
if node_ip == ip:
@@ -418,13 +468,16 @@ def get_node_of_ip(ip, node_dict):
def is_valid_node_action(action):
"""Is the node action an actual valid action
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_node_readable(action)
@@ -438,7 +491,10 @@ def is_valid_node_action(action):
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
@@ -446,13 +502,16 @@ def is_valid_node_action(action):
def is_valid_acl_action(action):
"""
Is the ACL action an actual valid action
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_acl_readable(action)
@@ -463,7 +522,11 @@ def is_valid_acl_action(action):
if action_decision == "NONE":
return False
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
# ACL rule towards itself
return False
if action_permission == "DENY":
@@ -475,7 +538,12 @@ def is_valid_acl_action(action):
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action"""
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if is_valid_acl_action(action) is False:
return False
@@ -494,33 +562,17 @@ def is_valid_acl_action_extra(action):
def get_new_action(old_action, action_dict):
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0]
old_action can be either node or acl action type
"""
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
Old_action can be either node or acl action type
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
for key, val in action_dict.items():
if list(val) == list(old_action):
return key
# Not all possible actions are included in dict, only valid action are
# if action is not in the dict, its an invalid action so return 0
return 0
def get_action_description(action, action_dict):
"""
Get a string describing/explaining what an action is doing in words
"""
action_array = action_dict[action]
if len(action_array) == 4:
# node actions have length 4
action_description = node_action_description(action_array)
elif len(action_array) == 6:
# acl actions have length 6
action_description = acl_action_description(action_array)
else:
# Should never happen
action_description = "Unrecognised action"
return action_description

View File

@@ -13,6 +13,8 @@ import yaml
from platformdirs import PlatformDirs
from typing_extensions import Annotated
from primaite.data_viz import PlotlyTemplate
app = typer.Typer()
@@ -54,7 +56,9 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
print(re.sub(r"\n*", "", line))
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
_LogLevel = Enum(
"LogLevel", {k: k for k in logging._levelToName.values()}
) # noqa
@app.command()
@@ -76,11 +80,12 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
primaite_config = yaml.safe_load(file)
if level:
primaite_config["log_level"] = level.value
primaite_config["logging"]["log_level"] = level.value
with open(user_config_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE Log Level: {level}")
else:
level = primaite_config["log_level"]
level = primaite_config["logging"]["log_level"]
print(f"PrimAITE Log Level: {level}")
@@ -170,16 +175,50 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml.
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
"""
from primaite.main import run
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.main import run
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = data_manipulation_config_path()
ldc = dos_very_basic_config_path()
run(training_config_path=tc, lay_down_config_path=ldc)
@app.command()
def plotly_template(
    template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None
):
    """
    View or set the plotly template for Session plots.

    To View, simply call: primaite plotly-template

    To set, call: primaite plotly-template <desired template>

    For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
    """
    # NOTE: the docstring above doubles as the CLI help text, so it is kept
    # verbatim.
    dirs = PlatformDirs(appname="primaite")
    dirs.user_config_path.mkdir(exist_ok=True, parents=True)
    config_file = dirs.user_config_path / "primaite_config.yaml"

    # Nothing to view or update when the user config has not been created.
    if not config_file.exists():
        return

    with open(config_file, "r") as stream:
        config = yaml.safe_load(stream)

    # Both the view and the set path operate on the same nested section.
    plots = config["session"]["outputs"]["plots"]

    if not template:
        # View mode: report the template currently stored in the config.
        print(f"PrimAITE plotly template: {plots['template']}")
        return

    # Set mode: persist the chosen template and echo it back.
    plots["template"] = template.value
    with open(config_file, "w") as stream:
        yaml.dump(config, stream)
    print(f"PrimAITE plotly template: {template.value}")

View File

@@ -83,6 +83,7 @@ class Protocol(Enum):
class SessionType(Enum):
"""The type of PrimAITE Session to be run."""
TRAIN = 1
"Train an agent"
EVAL = 2
@@ -93,6 +94,7 @@ class SessionType(Enum):
class VerboseLevel(IntEnum):
"""PrimAITE Session Output verbose level."""
NO_OUTPUT = 0
INFO = 1
DEBUG = 2
@@ -100,6 +102,7 @@ class VerboseLevel(IntEnum):
class AgentFramework(Enum):
"""The agent algorithm framework/package."""
CUSTOM = 0
"Custom Agent"
SB3 = 1
@@ -110,6 +113,7 @@ class AgentFramework(Enum):
class DeepLearningFramework(Enum):
"""The deep learning framework."""
TF = "tf"
"Tensorflow"
TF2 = "tf2"
@@ -120,6 +124,7 @@ class DeepLearningFramework(Enum):
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
@@ -136,6 +141,7 @@ class AgentIdentifier(Enum):
class HardCodedAgentView(Enum):
"""The view the deterministic hard-coded agent has of the environment."""
BASIC = 1
"The current observation space only"
FULL = 2
@@ -144,6 +150,7 @@ class HardCodedAgentView(Enum):
class ActionType(Enum):
"""Action type enumeration."""
NODE = 0
ACL = 1
ANY = 2
@@ -151,6 +158,7 @@ class ActionType(Enum):
class ObservationType(Enum):
"""Observation type enumeration."""
BOX = 0
MULTIDISCRETE = 1
@@ -193,6 +201,7 @@ class LinkStatus(Enum):
class OutputVerboseLevel(IntEnum):
"""The Agent output verbosity level."""
NONE = 0
"No Output"
INFO = 1

View File

@@ -35,10 +35,10 @@ hard_coded_agent_view: FULL
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
action_type: ANY
# Number of episodes to run per session
num_episodes: 1000
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
@@ -47,14 +47,14 @@ num_steps: 256
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay between steps (for generic agents)
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space

View File

@@ -1,20 +1,20 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final, Union, Dict, Any
from typing import Any, Dict, Final, Union
import networkx
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[
Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
_EXAMPLE_LAY_DOWN: Final[Path] = (
USERS_CONFIG_DIR / "example_config" / "lay_down"
)
def convert_legacy_lay_down_config_dict(
legacy_config_dict: Dict[str, Any]
legacy_config_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
Convert a legacy lay down config dict to the new format.
@@ -25,10 +25,7 @@ def convert_legacy_lay_down_config_dict(
return legacy_config_dict
def load(
file_path: Union[str, Path],
legacy_file: bool = False
) -> Dict:
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
"""
Read in a lay down config yaml file.

View File

@@ -7,15 +7,22 @@ from typing import Any, Dict, Final, Optional, Union
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import DeepLearningFramework, HardCodedAgentView
from primaite.common.enums import ActionType, AgentIdentifier, \
AgentFramework, SessionType, OutputVerboseLevel
from primaite import getLogger, USERS_CONFIG_DIR
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
HardCodedAgentView,
OutputVerboseLevel,
SessionType,
)
_LOGGER = getLogger(__name__)
_EXAMPLE_TRAINING: Final[
Path] = USERS_CONFIG_DIR / "example_config" / "training"
_EXAMPLE_TRAINING: Final[Path] = (
USERS_CONFIG_DIR / "example_config" / "training"
)
def main_training_config_path() -> Path:
@@ -36,6 +43,7 @@ def main_training_config_path() -> Path:
@dataclass()
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The AgentFramework"
@@ -171,12 +179,16 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system"
@classmethod
def from_dict(
cls,
config_dict: Dict[str, Union[str, int, bool]]
cls, config_dict: Dict[str, Union[str, int, bool]]
) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
:param config_dict: The training config dict.
:return: The instance of TrainingConfig.
"""
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
@@ -187,9 +199,9 @@ class TrainingConfig:
"hard_coded_agent_view": HardCodedAgentView,
}
for field, enum_class in field_enum_map.items():
if field in config_dict:
config_dict[field] = enum_class[config_dict[field]]
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True):
@@ -213,23 +225,21 @@ class TrainingConfig:
return data
def __str__(self) -> str:
tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, "
tc = f"{self.agent_framework}, "
if self.agent_framework is AgentFramework.RLLIB:
tc += f"deep_learning_framework=" \
f"{self.deep_learning_framework.name}, "
tc += f"agent_identifier={self.agent_identifier.name}, "
tc += f"{self.deep_learning_framework}, "
tc += f"{self.agent_identifier}, "
if self.agent_identifier is AgentIdentifier.HARDCODED:
tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, "
tc += f"action_type={self.action_type.name}, "
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={self.observation_space}, "
tc += f"num_episodes={self.num_episodes}, "
tc += f"num_steps={self.num_steps})"
tc += f"{self.num_episodes} episodes @ "
tc += f"{self.num_steps} steps"
return tc
def load(
file_path: Union[str, Path],
legacy_file: bool = False
file_path: Union[str, Path], legacy_file: bool = False
) -> TrainingConfig:
"""
Read in a training config yaml file.
@@ -273,12 +283,12 @@ def load(
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256,
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256,
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO,
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
@@ -301,8 +311,12 @@ def convert_legacy_training_config_dict(
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps,
"output_verbose_level": output_verbose_level
"output_verbose_level": output_verbose_level.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
legacy_config_dict["sessionType"] = session_type_map[
legacy_config_dict["sessionType"]
]
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:

View File

@@ -0,0 +1,13 @@
from enum import Enum
class PlotlyTemplate(Enum):
    """
    The built-in plotly templates.

    Each member's value is the lower-case template name string; the member
    name is the upper-case form of that string.
    """

    # Default plotly look and its light / dark variants.
    PLOTLY = "plotly"
    PLOTLY_WHITE = "plotly_white"
    PLOTLY_DARK = "plotly_dark"

    # Templates named after other plotting libraries.
    GGPLOT2 = "ggplot2"
    SEABORN = "seaborn"

    # Minimal templates.
    SIMPLE_WHITE = "simple_white"
    NONE = "none"

View File

@@ -0,0 +1,73 @@
from pathlib import Path
from typing import Dict, Optional, Union
import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure
from primaite import _PLATFORM_DIRS
def _get_plotly_config() -> Dict:
    """
    Get the plotly config from primaite_config.yaml.

    :return: The ``session -> outputs -> plots`` section of the user's
        ``primaite_config.yaml``.
    """
    config_file = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
    with open(config_file, "r") as stream:
        # Only the plots section is of interest to callers.
        plots_config = yaml.safe_load(stream)["session"]["outputs"]["plots"]
    return plots_config
def plot_av_reward_per_episode(
    av_reward_per_episode_csv: Union[str, Path],
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
) -> Figure:
    """
    Plot the average reward per episode from a csv session output.

    The csv is expected to provide ``Episode`` and ``Average Reward``
    columns. Figure size, template and range slider visibility are taken
    from the plot settings returned by ``_get_plotly_config()``.

    :param av_reward_per_episode_csv: The average reward per episode csv
        file path.
    :param title: The plot title. This is optional.
    :param subtitle: The plot subtitle. This is optional. When given
        without a title, it is used as the title.
    :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
    """
    df = pl.read_csv(av_reward_per_episode_csv)

    if title and subtitle:
        # Put the subtitle on its own line beneath the title. The <sup>
        # tag must be opened as well as closed, otherwise the stray
        # closing tag is left unmatched in the title markup.
        title = f"{title} <br><sup>{subtitle}</sup>"
    elif subtitle:
        # No title supplied: promote the subtitle to be the title.
        title = subtitle

    config = _get_plotly_config()
    layout = go.Layout(
        autosize=config["size"]["auto_size"],
        width=config["size"]["width"],
        height=config["size"]["height"],
    )

    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=config["template"])
    fig.add_trace(
        go.Scatter(
            x=df["Episode"],
            y=df["Average Reward"],
            mode="lines",
            name="Mean Reward per Episode",
        )
    )

    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
            "rangeslider": {"visible": config["range_slider"]},
        },
        yaxis={"title": "Average Reward"},
        title=title,
        showlegend=False,
    )
    return fig

View File

@@ -1,7 +1,7 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
from gym import spaces
@@ -77,7 +77,9 @@ class NodeLinkTable(AbstractObservationComponent):
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
self.current_observation = np.zeros(
observation_shape, dtype=self._DATA_TYPE
)
def update(self):
"""Update the observation based on current environment state.
@@ -92,7 +94,9 @@ class NodeLinkTable(AbstractObservationComponent):
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][
2
] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
@@ -199,9 +203,16 @@ class NodeStatuses(AbstractObservationComponent):
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
service_states[i] = node.get_service_state(
service
).value
obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
[
hardware_state,
software_state,
file_system_state,
*service_states,
]
)
self.current_observation[:] = obs
@@ -259,7 +270,9 @@ class LinkTrafficLevels(AbstractObservationComponent):
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
[self._quantisation_levels]
* self.env.num_links
* self._entries_per_link
)
# 2. Create Observation space
@@ -279,7 +292,9 @@ class LinkTrafficLevels(AbstractObservationComponent):
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
loads = [
protocol.get_load() for protocol in link.protocol_list
]
for load in loads:
if load <= 0:

View File

@@ -2,7 +2,7 @@
"""Main environment module containing the PRIMmary AI Training Environment (Primaite) class."""
import copy
from pathlib import Path
from typing import Dict, Tuple, Union, Final
from typing import Dict, Final, Tuple, Union
import networkx as nx
import numpy as np
@@ -12,8 +12,10 @@ from matplotlib import pyplot as plt
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
from primaite.agents.utils import (
is_valid_acl_action_extra,
is_valid_node_action,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -24,7 +26,8 @@ from primaite.common.enums import (
NodeType,
ObservationType,
Priority,
SoftwareState, SessionType,
SessionType,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config import training_config
@@ -34,15 +37,18 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import \
NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import (
NodeStateInstructionGreen,
)
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.pol.red_agent_pol import (
apply_red_agent_iers,
apply_red_agent_node_pol,
)
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
@@ -59,11 +65,11 @@ class Primaite(Env):
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
session_path: Path,
timestamp_str: str,
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
session_path: Path,
timestamp_str: str,
):
"""
The Primaite constructor.
@@ -237,27 +243,19 @@ class Primaite(Env):
)
self.episode_av_reward_writer = SessionOutputWriter(
self,
transaction_writer=False,
learning_session=True
self, transaction_writer=False, learning_session=True
)
self.transaction_writer = SessionOutputWriter(
self,
transaction_writer=True,
learning_session=True
self, transaction_writer=True, learning_session=True
)
def set_as_eval(self):
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(
self,
transaction_writer=False,
learning_session=False
self, transaction_writer=False, learning_session=False
)
self.transaction_writer = SessionOutputWriter(
self,
transaction_writer=True,
learning_session=False
self, transaction_writer=True, learning_session=False
)
self.episode_count = 0
self.step_count = 0
@@ -322,9 +320,7 @@ class Primaite(Env):
# Create a Transaction (metric) object for this step
transaction = Transaction(
self.agent_identifier,
self.episode_count,
self.step_count
self.agent_identifier, self.episode_count, self.step_count
)
# Load the initial observation space into the transaction
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
@@ -354,8 +350,9 @@ class Primaite(Env):
self.nodes_post_pol = copy.deepcopy(self.nodes)
self.links_post_pol = copy.deepcopy(self.links)
# Reference
apply_node_pol(self.nodes_reference, self.node_pol,
self.step_count) # Node PoL
apply_node_pol(
self.nodes_reference, self.node_pol, self.step_count
) # Node PoL
apply_iers(
self.network_reference,
self.nodes_reference,
@@ -404,8 +401,10 @@ class Primaite(Env):
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
_LOGGER.info(f"Episode: {self.episode_count}, "
f"Average Reward: {self.average_reward}")
_LOGGER.info(
f"Episode: {self.episode_count}, "
f"Average Reward: {self.average_reward}"
)
# Load the reward into the transaction
transaction.reward = reward
@@ -452,11 +451,11 @@ class Primaite(Env):
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
@@ -525,7 +524,7 @@ class Primaite(Env):
# Patch (valid action if it's good or compromised)
node.set_service_state(
self.services_list[service_index],
SoftwareState.PATCHING
SoftwareState.PATCHING,
)
else:
# Node is not of Service Type
@@ -542,7 +541,10 @@ class Primaite(Env):
elif property_action == 2:
# Repair
# You cannot repair a destroyed file system - it needs restoring
if node.file_system_state_actual != FileSystemState.DESTROYED:
if (
node.file_system_state_actual
!= FileSystemState.DESTROYED
):
node.set_file_system_state(FileSystemState.REPAIRING)
elif property_action == 3:
# Restore
@@ -585,8 +587,9 @@ class Primaite(Env):
acl_rule_source = "ANY"
else:
node = list(self.nodes.values())[action_source_ip - 1]
if isinstance(node, ServiceNode) or isinstance(node,
ActiveNode):
if isinstance(node, ServiceNode) or isinstance(
node, ActiveNode
):
acl_rule_source = node.ip_address
else:
return
@@ -595,8 +598,9 @@ class Primaite(Env):
acl_rule_destination = "ANY"
else:
node = list(self.nodes.values())[action_destination_ip - 1]
if isinstance(node, ServiceNode) or isinstance(node,
ActiveNode):
if isinstance(node, ServiceNode) or isinstance(
node, ActiveNode
):
acl_rule_destination = node.ip_address
else:
return
@@ -681,8 +685,9 @@ class Primaite(Env):
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
self.obs_handler = ObservationsHandler.from_config(self,
self.obs_config)
self.obs_handler = ObservationsHandler.from_config(
self, self.obs_config
)
return self.obs_handler.space, self.obs_handler.current_observation
@@ -790,7 +795,8 @@ class Primaite(Env):
service_port = service["port"]
service_state = SoftwareState[service["state"]]
node.add_service(
Service(service_protocol, service_port, service_state))
Service(service_protocol, service_port, service_state)
)
else:
# Bad formatting
pass
@@ -843,8 +849,9 @@ class Primaite(Env):
dest_node_ref: Node = self.nodes_reference[link_destination]
# Add link to network (reference)
self.network_reference.add_edge(source_node_ref, dest_node_ref,
id=link_name)
self.network_reference.add_edge(
source_node_ref, dest_node_ref, id=link_name
)
# Add link to link dictionary (reference)
self.links_reference[link_name] = Link(
@@ -1120,7 +1127,8 @@ class Primaite(Env):
node_id = item["node_id"]
node_class = item["node_class"]
node_hardware_state: HardwareState = HardwareState[
item["hardware_state"]]
item["hardware_state"]
]
node: NodeUnion = self.nodes[node_id]
node_ref = self.nodes_reference[node_id]
@@ -1186,8 +1194,12 @@ class Primaite(Env):
# Use MAX to ensure we get them all
for node_action in range(4):
for service_state in range(self.num_services):
action = [node, node_property, node_action,
service_state]
action = [
node,
node_property,
node_action,
service_state,
]
# check to see if it's a nothing action (has no effect)
if is_valid_node_action(action):
actions[action_key] = action

View File

@@ -46,7 +46,9 @@ def calculate_reward_function(
)
# Software State
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
if isinstance(final_node, ActiveNode) or isinstance(
final_node, ServiceNode
):
reward_value += score_node_os_state(
final_node, initial_node, reference_node, config_values
)
@@ -81,7 +83,8 @@ def calculate_reward_function(
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked * ier_value.get_mission_criticality()
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
)
if live_blocked and not reference_blocked:
@@ -104,7 +107,9 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
def score_node_operating_state(
final_node, initial_node, reference_node, config_values
):
"""
Calculates score relating to the hardware state of a node.
@@ -153,7 +158,9 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values):
def score_node_os_state(
final_node, initial_node, reference_node, config_values
):
"""
Calculates score relating to the Software State of a node.
@@ -204,7 +211,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values):
def score_node_service_state(
final_node, initial_node, reference_node, config_values
):
"""
Calculates score relating to the service state(s) of a node.
@@ -276,7 +285,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values):
def score_node_file_system(
final_node, initial_node, reference_node, config_values
):
"""
Calculates score relating to the file system state of a node.

View File

@@ -8,7 +8,9 @@ from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
def __init__(
self, _id, _bandwidth, _source_node_name, _dest_node_name, _services
):
"""
Init.

View File

@@ -10,7 +10,10 @@ from primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
"""Run the PrimAITE Session.
:param training_config_path: The training config filepath.

View File

@@ -87,7 +87,9 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
def set_software_state_if_not_compromised(
self, software_state: SoftwareState
):
"""
Sets Software State if the node is not compromised.
@@ -98,7 +100,9 @@ class ActiveNode(Node):
if self._software_state != SoftwareState.COMPROMISED:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = self.config_values.os_patching_duration
self.patching_count = (
self.config_values.os_patching_duration
)
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so OS State cannot be changed."
@@ -187,7 +191,9 @@ class ActiveNode(Node):
def start_file_system_scan(self):
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
self.file_system_scanning_count = (
self.config_values.file_system_scanning_limit
)
def update_file_system_state(self):
"""Updates file system status based on scanning/restore/repair cycle."""
@@ -206,7 +212,10 @@ class ActiveNode(Node):
self.file_system_state_observed = FileSystemState.GOOD
# Scanning updates
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
if (
self.file_system_scanning == True
and self.file_system_scanning_count < 0
):
self.file_system_state_observed = self.file_system_state_actual
self.file_system_scanning = False
self.file_system_scanning_count = 0

View File

@@ -32,7 +32,9 @@ class NodeStateInstructionGreen(object):
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type = _node_pol_type
self.service_name = _service_name # Not used when not a service instruction
self.service_name = (
_service_name # Not used when not a service instruction
)
self.state = _state
def get_start_step(self):

View File

@@ -42,7 +42,9 @@ class NodeStateInstructionRed(object):
self.target_node_id = _target_node_id
self.initiator = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name = pol_protocol # Not used when not a service instruction
self.service_name = (
pol_protocol # Not used when not a service instruction
)
self.state = _pol_state
self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service

View File

@@ -110,7 +110,9 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
def set_service_state(
self, protocol_name: str, software_state: SoftwareState
):
"""
Sets the software_state of a service (protocol) on the node.

View File

@@ -4,7 +4,7 @@ import os
import subprocess
import sys
from primaite import NOTEBOOKS_DIR, getLogger
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)

View File

@@ -6,10 +6,17 @@ from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
from primaite.common.enums import (
HardwareState,
NodePOLType,
NodeType,
SoftwareState,
)
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import (
NodeStateInstructionGreen,
)
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
@@ -190,7 +197,9 @@ def apply_iers(
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth():
if (
link.get_current_load() + load
) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
@@ -204,7 +213,8 @@ def apply_iers(
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
@@ -216,7 +226,9 @@ def apply_iers(
else:
# One of the nodes is not operational
if _VERBOSE:
print("Path not valid - one or more nodes not operational")
print(
"Path not valid - one or more nodes not operational"
)
pass
else:
@@ -231,7 +243,9 @@ def apply_iers(
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
node_pol: Dict[
any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]
],
step: int,
):
"""
@@ -263,16 +277,22 @@ def apply_node_pol(
elif node_pol_type == NodePOLType.OS:
# Change OS state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
if isinstance(node, ActiveNode) or isinstance(
node, ServiceNode
):
node.set_software_state_if_not_compromised(state)
elif node_pol_type == NodePOLType.SERVICE:
# Change a service state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ServiceNode):
node.set_service_state_if_not_compromised(service_name, state)
node.set_service_state_if_not_compromised(
service_name, state
)
else:
# Change the file system status
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
if isinstance(node, ActiveNode) or isinstance(
node, ServiceNode
):
node.set_file_system_state_if_not_compromised(state)
else:
# PoL is not valid in this time step

View File

@@ -176,7 +176,9 @@ def apply_red_agent_iers(
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth():
if (
link.get_current_load() + load
) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
@@ -190,7 +192,8 @@ def apply_red_agent_iers(
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
@@ -200,16 +203,23 @@ def apply_red_agent_iers(
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
if _VERBOSE:
print("Red IER was allowed to run in step " + str(step))
print(
"Red IER was allowed to run in step "
+ str(step)
)
else:
# One of the nodes is not operational
if _VERBOSE:
print("Path not valid - one or more nodes not operational")
print(
"Path not valid - one or more nodes not operational"
)
pass
else:
if _VERBOSE:
print("Red IER was NOT allowed to run in step " + str(step))
print(
"Red IER was NOT allowed to run in step " + str(step)
)
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
@@ -264,7 +274,9 @@ def apply_red_agent_node_pol(
passed_checks = True
elif initiator == NodePOLInitiator.IER:
# Need to check there is a red IER incoming
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
passed_checks = is_red_ier_incoming(
target_node, iers, pol_type
)
elif initiator == NodePOLInitiator.SERVICE:
# Need to check the condition of a service on another node
source_node = nodes[source_node_id]
@@ -308,7 +320,9 @@ def apply_red_agent_node_pol(
target_node.set_file_system_state(state)
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - did not pass checks")
print(
"Node Red Agent PoL not allowed - did not pass checks"
)
else:
# PoL is not valid in this time step
pass
@@ -323,7 +337,10 @@ def is_red_ier_incoming(node, iers, node_pol_type):
node_id = node.node_id
for ier_key, ier_value in iers.items():
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
if (
ier_value.get_is_running()
and ier_value.get_dest_node_id() == node_id
):
if (
node_pol_type == NodePOLType.OPERATING
or node_pol_type == NodePOLType.OS

View File

@@ -1,54 +1,51 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Final, Optional, Union, Dict
from uuid import uuid4
from typing import Dict, Final, Optional, Union
from primaite import getLogger, SESSIONS_DIR
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \
RandomAgent, DummyAgent
from primaite.common.enums import AgentFramework, AgentIdentifier, \
ActionType, SessionType
from primaite.agents.simple import (
DoNothingACLAgent,
DoNothingNodeAgent,
DummyAgent,
RandomAgent,
)
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
SessionType,
)
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
return session_path
class PrimaiteSession:
"""
The PrimaiteSession class.
Provides a single learning and evaluation entry point for all training
and lay down configurations.
"""
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path]
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
@@ -64,22 +61,35 @@ class PrimaiteSession:
)
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self):
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}"
)
if (
self._training_config.agent_identifier
== AgentIdentifier.HARDCODED
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.HARDCODED}"
)
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ANY:
@@ -90,18 +100,23 @@ class PrimaiteSession:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
elif (
self._training_config.agent_identifier
== AgentIdentifier.DO_NOTHING
):
# Log the selected agent identifier. The member must be DO_NOTHING (the
# same one tested in the enclosing condition); a misspelt member name
# raises AttributeError as soon as the f-string is evaluated.
_LOGGER.debug(
    f"PrimaiteSession Setup: Agent Identifier ="
    f" {AgentIdentifier.DO_NOTHING}"
)
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ANY:
@@ -112,15 +127,26 @@ class PrimaiteSession:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
self._agent_session = RandomAgent(
self._training_config_path,
self._lay_down_config_path
elif (
self._training_config.agent_identifier
== AgentIdentifier.RANDOM
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.RANDOM}"
)
self._agent_session = RandomAgent(
self._training_config_path, self._lay_down_config_path
)
elif (
self._training_config.agent_identifier == AgentIdentifier.DUMMY
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.DUMMY}"
)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
self._agent_session = DummyAgent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
else:
@@ -128,37 +154,64 @@ class PrimaiteSession:
raise ValueError
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}"
)
# Stable Baselines3 Agent
self._agent_session = SB3Agent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}"
)
# Ray RLlib Agent
self._agent_session = RLlibAgent(
self._training_config_path,
self._lay_down_config_path
self._training_config_path, self._lay_down_config_path
)
else:
# Invalid AgentFramework
raise ValueError
self.session_path: Path = self._agent_session.session_path
self.timestamp_str: str = self._agent_session.timestamp_str
self.learning_path: Path = self._agent_session.learning_path
self.evaluation_path: Path = self._agent_session.evaluation_path
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of time steps per episode.
:param episodes: The number of episodes.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.EVAL:
self._agent_session.learn(time_steps, episodes, **kwargs)
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of time steps per episode.
:param episodes: The number of episodes.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(time_steps, episodes, **kwargs)
def close(self):
"""Closes the agent."""
self._agent_session.close()

View File

@@ -9,3 +9,14 @@ logging:
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
# Session
session:
outputs:
plots:
size:
auto_size: false
width: 1500
height: 900
template: plotly_white
range_slider: false

View File

@@ -6,7 +6,7 @@ from pathlib import Path
import pkg_resources
from primaite import NOTEBOOKS_DIR, getLogger
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
@@ -24,7 +24,9 @@ def run(overwrite_existing: bool = True):
for subdir, dirs, files in os.walk(notebooks_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
path_split = os.path.relpath(
fp, notebooks_package_data_root
).split(os.sep)
target_fp = NOTEBOOKS_DIR / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import pkg_resources
from primaite import USERS_CONFIG_DIR, getLogger
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
@@ -24,7 +24,9 @@ def run(overwrite_existing=True):
for subdir, dirs, files in os.walk(configs_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
path_split = os.path.relpath(fp, configs_package_data_root).split(
os.sep
)
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()

View File

@@ -1,5 +1,5 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)

View File

@@ -7,12 +7,7 @@ from typing import List, Tuple
class Transaction(object):
"""Transaction class."""
def __init__(
self,
agent_identifier,
episode_number,
step_number
):
def __init__(self, agent_identifier, episode_number, step_number):
"""
Transaction constructor.
@@ -37,6 +32,11 @@ class Transaction(object):
"The action space invoked by the agent"
def as_csv_data(self) -> Tuple[List, List]:
"""
Converts the Transaction to a csv data row and provides a header.
:return: A tuple consisting of (header, data).
"""
if isinstance(self.action_space, int):
action_length = self.action_space
else:
@@ -74,12 +74,14 @@ class Transaction(object):
str(self.reward),
]
row = (
row
+ _turn_action_space_to_array(self.action_space)
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets,
obs_features)
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets,
obs_features)
row
+ _turn_action_space_to_array(self.action_space)
+ _turn_obs_space_to_array(
self.obs_space_pre, obs_assets, obs_features
)
+ _turn_obs_space_to_array(
self.obs_space_post, obs_assets, obs_features
)
)
return header, row

View File

@@ -0,0 +1,20 @@
from pathlib import Path
from typing import Dict, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
import polars as pl
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
    """
    Load an average-rewards-per-episode csv file into a dictionary.

    The file is read with polars and is expected to provide the columns
    ``Episode`` and ``Average Reward``.

    :param av_rewards_csv_file: The average rewards per episode csv file path.
    :return: A dict whose keys are the episode numbers and whose values are
        the mean reward for that episode.
    """
    columns = pl.read_csv(av_rewards_csv_file).to_dict()
    episodes = columns["Episode"]
    rewards = columns["Average Reward"]
    # Pair the two columns up row by row.
    return dict(zip(episodes, rewards))

View File

@@ -1,7 +1,6 @@
import csv
from logging import Logger
from typing import List, Final, IO, Union, Tuple
from typing import TYPE_CHECKING
from typing import Final, List, Tuple, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.transactions.transaction import Transaction
@@ -13,15 +12,22 @@ _LOGGER: Logger = getLogger(__name__)
class SessionOutputWriter:
"""
A session output writer class.
Is used to write session outputs to csv file.
"""
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
"Episode", "Average Reward"
"Episode",
"Average Reward",
]
def __init__(
self,
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True
self,
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True,
):
self._env = env
self.transaction_writer = transaction_writer
@@ -52,14 +58,21 @@ class SessionOutputWriter:
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
self.close()
def close(self):
"""Close the csv file."""
if self._csv_file:
self._csv_file.close()
_LOGGER.info(f"Finished writing file: {self._csv_file_path}")
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(
self,
data: Union[Tuple, Transaction]
):
def write(self, data: Union[Tuple, Transaction]):
"""
Write a row of session data.
:param data: The row of data to write. Can be a Tuple or an instance
of Transaction.
"""
if isinstance(data, Transaction):
header, data = data.as_csv_data()
else:
@@ -69,5 +82,4 @@ class SessionOutputWriter:
self._init_csv_writer()
self._csv_writer.writerow(header)
self._first_write = False
self._csv_writer.writerow(data)

View File

@@ -6,7 +6,7 @@
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray[RLlib])
# "NONE" (Custom Agent)
agent_framework: RLLIB
agent_framework: SB3
# Sets which Red Agent algo/class will be used:
# "PPO" (Proximal Policy Optimization)
@@ -27,7 +27,7 @@ num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: A2C
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -28,7 +39,7 @@ observation_space:
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -24,7 +35,7 @@ observation_space:
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -25,7 +36,7 @@ observation_space:
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: NONE
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,10 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +30,7 @@ num_steps: 15
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,11 +1,22 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: RANDOM
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
@@ -18,7 +29,7 @@ num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (e.g. TRAIN or EVAL)
session_type: TRAINING
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in

View File

@@ -1,37 +1,151 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import datetime
import shutil
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Union
from typing import Dict, Union
from unittest.mock import patch
import pytest
from primaite import getLogger
from primaite.common.enums import AgentIdentifier
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
def _get_temp_session_path(session_timestamp: datetime) -> Path:
class TempPrimaiteSession(PrimaiteSession):
    """
    A PrimaiteSession intended for tests that cleans up after itself.

    It is used as a context manager: leaving the ``with`` block removes the
    session directory and its parent directory from disk.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
    ):
        """
        Build the session and run its setup immediately.

        :param training_config_path: The training config file path.
        :param lay_down_config_path: The lay down config file path.
        """
        super().__init__(training_config_path, lay_down_config_path)
        self.setup()

    def learn_av_reward_per_episode(self) -> Dict[int, float]:
        """Read the learning average reward per episode csv as a dict."""
        file_name = f"average_reward_per_episode_{self.timestamp_str}.csv"
        return av_rewards_dict(self.learning_path / file_name)

    def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
        """Read the evaluation average reward per episode csv as a dict."""
        file_name = f"average_reward_per_episode_{self.timestamp_str}.csv"
        return av_rewards_dict(self.evaluation_path / file_name)

    @property
    def env(self) -> Primaite:
        """Expose the agent session's env directly for ease of testing."""
        return self._agent_session._env  # noqa

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        env = self._agent_session._env
        # Drop the env's output writers before removing the directories.
        # NOTE(review): presumably this lets the writers close their csv
        # files first - confirm against the writer implementation.
        for writer_attr in ("episode_av_reward_writer", "transaction_writer"):
            delattr(env, writer_attr)
        shutil.rmtree(self.session_path)
        shutil.rmtree(self.session_path.parent)
        _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@pytest.fixture
def temp_primaite_session(request):
    """
    Provide a temporary PrimaiteSession instance.

    It is temporary because the session path is created inside a temporary
    directory rather than the usual session output location.

    To use this fixture you need to:

    - parametrize your test function with:

      - "temp_primaite_session"
      - [[path to training config, path to lay down config]]
    - Include the temp_primaite_session fixture as a param in your test
      function.
    - Use the temp_primaite_session as a context manager, assigning it the
      name 'session'.

    .. code:: python

        from primaite.config.lay_down_config import dos_very_basic_config_path
        from primaite.config.training_config import main_training_config_path
        @pytest.mark.parametrize(
            "temp_primaite_session",
            [
                [main_training_config_path(), dos_very_basic_config_path()]
            ],
            indirect=True
        )
        def test_primaite_session(temp_primaite_session):
            with temp_primaite_session as session:
                # Learning outputs are saved in session.learning_path
                session.learn()

                # Evaluation outputs are saved in session.evaluation_path
                session.evaluate()

                # To ensure that all files are written, you must call .close()
                session.close()

                # If you need to inspect any session outputs, it must be done
                # inside the context manager

            # Now that we've exited the context manager, the
            # session.session_path directory and its contents are deleted
    """
    training_config_path = request.param[0]
    lay_down_config_path = request.param[1]
    # The session path lookup is only patched while the session is being
    # constructed; the patch is undone before the session is handed back.
    with patch(
        "primaite.agents.agent.get_session_path", get_temp_session_path
    ) as mck:
        mck.session_timestamp = datetime.now()
        session = TempPrimaiteSession(
            training_config_path, lay_down_config_path
        )
    return session
@pytest.fixture
def temp_session_path() -> Path:
"""
Get a temp directory session path the test session will output to.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path = (
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
)
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def _get_primaite_env_from_config(
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
temp_session_path,
):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp)
session_path = temp_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
@@ -45,7 +159,7 @@ def _get_primaite_env_from_config(
# TODO: This needs to be refactored to happen outside. Should be part of
# a main Session class.
if env.training_config.agent_identifier == "GENERIC":
if env.training_config.agent_identifier is AgentIdentifier.RANDOM:
run_generic(env, config_values)
return env

View File

@@ -1,8 +0,0 @@
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
def test_primaite_main_e2e():
"""Tests the primaite.main.run function end-to-end."""
run(main_training_config_path(), data_manipulation_config_path())

View File

@@ -0,0 +1,24 @@
import tempfile
from datetime import datetime
from pathlib import Path
from primaite import getLogger
_LOGGER = getLogger(__name__)
def get_temp_session_path(session_timestamp: datetime) -> Path:
    """
    Create the temp directory that a test session will write its output to.

    The directory is located at
    ``<system temp dir>/primaite/<YYYY-MM-DD>/<YYYY-MM-DD_HH-MM-SS>`` and is
    created (along with any missing parents) before being returned.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    day_dir_name = session_timestamp.strftime("%Y-%m-%d")
    session_dir_name = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    temp_root = Path(tempfile.gettempdir())
    session_path = temp_root / "primaite" / day_dir_name / session_dir_name
    # exist_ok avoids an error if the directory is already there.
    session_path.mkdir(exist_ok=True, parents=True)
    _LOGGER.debug(f"Created temp session directory: {session_path}")
    return session_path

View File

@@ -60,7 +60,9 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state):
1,
)
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
active_node.set_software_state_if_not_compromised(
SoftwareState.OVERWHELMED
)
assert active_node.software_state == expected_state
@@ -98,7 +100,9 @@ def test_file_system_change(operating_state, expected_state):
(HardwareState.ON, FileSystemState.CORRUPT),
],
)
def test_file_system_change_if_not_compromised(operating_state, expected_state):
def test_file_system_change_if_not_compromised(
operating_state, expected_state
):
"""
Test that a node cannot change its file system state.
@@ -116,6 +120,8 @@ def test_file_system_change_if_not_compromised(operating_state, expected_state):
1,
)
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
active_node.set_file_system_state_if_not_compromised(
FileSystemState.CORRUPT
)
assert active_node.file_system_state_actual == expected_state

View File

@@ -7,79 +7,78 @@ from primaite.environment.observations import (
NodeStatuses,
ObservationsHandler,
)
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
@pytest.fixture
def env(request):
"""Build Primaite environment for integration tests of observation space."""
marker = request.node.get_closest_marker("env_config_paths")
training_config_path = marker.args[0]["training_config_path"]
lay_down_config_path = marker.args[0]["lay_down_config_path"]
env = _get_primaite_env_from_config(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
)
yield env
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
def test_default_obs_space(env: Primaite):
def test_default_obs_space(temp_primaite_session):
"""Create environment with no obs space defined in config and check that the default obs space was created."""
env.update_environent_obs()
with temp_primaite_session as session:
session.env.update_environent_obs()
components = env.obs_handler.registered_obs_components
components = session.env.obs_handler.registered_obs_components
assert len(components) == 1
assert isinstance(components[0], NodeLinkTable)
assert len(components) == 1
assert isinstance(components[0], NodeLinkTable)
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
def test_registering_components(env: Primaite):
def test_registering_components(temp_primaite_session):
"""Test registering and deregistering a component."""
handler = ObservationsHandler()
component = NodeStatuses(env)
handler.register(component)
assert component in handler.registered_obs_components
handler.deregister(component)
assert component not in handler.registered_obs_components
with temp_primaite_session as session:
env = session.env
handler = ObservationsHandler()
component = NodeStatuses(env)
handler.register(component)
assert component in handler.registered_obs_components
handler.deregister(component)
assert component not in handler.registered_obs_components
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_LINK_TABLE.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
class TestNodeLinkTable:
"""Test the NodeLinkTable observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with box observation space."""
env.update_environent_obs()
with temp_primaite_session as session:
env = session.env
env.update_environent_obs()
# we have three nodes and two links, with two service
# therefore the box observation space will have:
# * 5 rows (3 nodes + 2 links)
# * 6 columns (four fixed and two for the services)
assert env.env_obs.shape == (5, 6)
# we have three nodes and two links, with two service
# therefore the box observation space will have:
# * 5 rows (3 nodes + 2 links)
# * 6 columns (four fixed and two for the services)
assert env.env_obs.shape == (5, 6)
def test_value(self, env: Primaite):
def test_value(self, temp_primaite_session):
"""Test that the observation is generated correctly.
The laydown has:
@@ -125,36 +124,45 @@ class TestNodeLinkTable:
* 999 (999 traffic service1)
* 0 (no traffic for service2)
"""
# act = np.asarray([0,])
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
with temp_primaite_session as session:
env = session.env
# act = np.asarray([0,])
obs, reward, done, info = env.step(
0
) # apply the 'do nothing' action
assert np.array_equal(
obs,
[
[1, 1, 3, 1, 1, 1],
[2, 1, 1, 1, 1, 4],
[3, 1, 1, 1, 0, 0],
[4, 0, 0, 0, 999, 0],
[5, 0, 0, 0, 999, 0],
],
)
assert np.array_equal(
obs,
[
[1, 1, 3, 1, 1, 1],
[2, 1, 1, 1, 1, 4],
[3, 1, 1, 1, 0, 0],
[4, 0, 0, 0, 999, 0],
[5, 0, 0, 0, 999, 0],
],
)
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_STATUSES.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
class TestNodeStatuses:
"""Test the NodeStatuses observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with NodeStatuses as the only component."""
assert env.env_obs.shape == (15,)
with temp_primaite_session as session:
env = session.env
assert env.env_obs.shape == (15,)
def test_values(self, env: Primaite):
def test_values(self, temp_primaite_session):
"""Test that the hardware and software states are encoded correctly.
The laydown has:
@@ -181,28 +189,38 @@ class TestNodeStatuses:
* service 1 = n/a (0)
* service 2 = n/a (0)
"""
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
with temp_primaite_session as session:
env = session.env
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]
)
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],
indirect=True,
)
class TestLinkTrafficLevels:
"""Test the LinkTrafficLevels observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with MultiDiscrete observation space."""
env.update_environent_obs()
with temp_primaite_session as session:
env = session.env
env.update_environent_obs()
# we have two links and two services, so the shape should be 2 * 2
assert env.env_obs.shape == (2 * 2,)
# we have two links and two services, so the shape should be 2 * 2
assert env.env_obs.shape == (2 * 2,)
def test_values(self, env: Primaite):
def test_values(self, temp_primaite_session):
"""Test that traffic values are encoded correctly.
The laydown has:
@@ -212,12 +230,14 @@ class TestLinkTrafficLevels:
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
* link bandwidth of 1000, therefore the utilisation is 99.9%
"""
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
with temp_primaite_session as session:
env = session.env
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
# the observation space has combine_service_traffic set to False, so the space has this format:
# [link1_service1, link1_service2, link2_service1, link2_service2]
# we send 999 bits of data via link1 and link2 on service 1.
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])
# the observation space has combine_service_traffic set to False, so the space has this format:
# [link1_service1, link1_service2, link2_service1, link2_service2]
# we send 999 bits of data via link1 and link2 on service 1.
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])

View File

@@ -0,0 +1,61 @@
import os
import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
    "temp_primaite_session",
    [[main_training_config_path(), dos_very_basic_config_path()]],
    indirect=True,
)
def test_primaite_session(temp_primaite_session):
    """Run a learn/evaluate session and check the files it leaves behind."""
    allowed_csv_names = ("all_transactions", "average_reward_per_episode")

    def _assert_only_expected_csv_files(output_dir):
        # Every csv file found must be either a transactions file or an
        # average reward per episode file.
        for file in output_dir.iterdir():
            if file.suffix == ".csv":
                assert any(name in file.name for name in allowed_csv_names)

    with temp_primaite_session as session:
        session_path = session.session_path
        assert session_path.exists()

        # Learning outputs are saved in session.learning_path
        session.learn()
        # Evaluation outputs are saved in session.evaluation_path
        session.evaluate()

        # Any inspection of session outputs has to happen inside the
        # context manager, before the directory is deleted.

        # The metadata json file must have been written
        metadata_file = session_path / "session_metadata.json"
        assert metadata_file.exists()

        # The network png file must have been written
        network_png = session_path / f"network_{session.timestamp_str}.png"
        assert network_png.exists()

        _assert_only_expected_csv_files(session.learning_path)
        _assert_only_expected_csv_files(session.evaluation_path)

        _LOGGER.debug("Inspecting files in temp session path...")
        for dir_path, dir_names, file_names in os.walk(session_path):
            for file in file_names:
                path = os.path.join(dir_path, file)
                file_str = path.split(str(session_path))[-1]
                _LOGGER.debug(f"  {file_str}")

    # Leaving the context manager deletes the session.session_path
    # directory and its contents.
    assert not session_path.exists()

View File

@@ -18,7 +18,9 @@ from primaite.nodes.service_node import ServiceNode
"starting_operating_state, expected_operating_state",
[(HardwareState.RESETTING, HardwareState.ON)],
)
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
def test_node_resets_correctly(
starting_operating_state, expected_operating_state
):
"""Tests that a node resets correctly."""
active_node = ActiveNode(
node_id="0",

View File

@@ -1,26 +1,33 @@
import pytest
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def test_rewards_are_being_penalised_at_each_step_function():
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
]
],
indirect=True,
)
def test_rewards_are_being_penalised_at_each_step_function(
temp_primaite_session,
):
"""
Test that hardware state is penalised at each step.
When the initial state is OFF compared to reference state which is ON.
"""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_lay_down_config.yaml",
)
"""
On different steps (of the 13 in total) these are the following rewards for config_6 which are activated:
On different steps (of the 13 in total) these are the following rewards
for config_6 which are activated:
File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3)
Hardware State: onShouldBeOff = -2 (between Steps 4 & 6)
Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9)
Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12)
Software State (Software State): goodShouldBeCompromised = 5 (between
Steps 10 & 12)
Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
Step Count: 13
@@ -28,5 +35,8 @@ def test_rewards_are_being_penalised_at_each_step_function():
For the 4 steps where this occurs the average reward is:
Average Reward: 2 (26 / 13)
"""
print("average reward", env.average_reward)
assert env.average_reward == -8.0
with temp_primaite_session as session:
session.evaluate()
session.close()
ev_rewards = session.eval_av_reward_per_episode_csv()
assert ev_rewards[1] == -8.0

View File

@@ -45,7 +45,9 @@ def test_service_state_change(operating_state, expected_state):
(HardwareState.ON, SoftwareState.OVERWHELMED),
],
)
def test_service_state_change_if_not_comprised(operating_state, expected_state):
def test_service_state_change_if_not_comprised(
operating_state, expected_state
):
"""
Test that a node cannot change the state of a running service.
@@ -65,6 +67,8 @@ def test_service_state_change_if_not_comprised(operating_state, expected_state):
service = Service("TCP", 80, SoftwareState.GOOD)
service_node.add_service(service)
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
service_node.set_service_state_if_not_compromised(
"TCP", SoftwareState.OVERWHELMED
)
assert service_node.get_service_state("TCP") == expected_state

View File

@@ -1,9 +1,10 @@
import time
import pytest
from primaite.common.enums import HardwareState
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def run_generic_set_actions(env: Primaite):
@@ -44,59 +45,72 @@ def run_generic_set_actions(env: Primaite):
# env.close()
def test_single_action_space_is_valid():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
]
],
indirect=True,
)
def test_single_action_space_is_valid(temp_primaite_session):
"""Test single action space is valid."""
with temp_primaite_session as session:
env = session.env
run_generic_set_actions(env)
# Retrieve the action space dictionary values from environment
env_action_space_dict = env.action_dict.values()
# Flags to check the conditions of the action space
contains_acl_actions = False
contains_node_actions = False
both_action_spaces = False
# Loop through each element of the list (which is every value from the dictionary)
for dict_item in env_action_space_dict:
# Node action detected
if len(dict_item) == 4:
contains_node_actions = True
# ACL action detected
elif len(dict_item) == 6:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
both_action_spaces = True
# Check condition should be True
assert both_action_spaces
run_generic_set_actions(env)
# Retrieve the action space dictionary values from environment
env_action_space_dict = env.action_dict.values()
# Flags to check the conditions of the action space
contains_acl_actions = False
contains_node_actions = False
both_action_spaces = False
# Loop through each element of the list (which is every value from the dictionary)
for dict_item in env_action_space_dict:
# Node action detected
if len(dict_item) == 4:
contains_node_actions = True
# ACL action detected
elif len(dict_item) == 6:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
both_action_spaces = True
# Check condition should be True
assert both_action_spaces
def test_agent_is_executing_actions_from_both_spaces():
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
]
],
indirect=True,
)
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)
# Retrieve hardware state of computer_1 node in laydown config
# Agent turned this off in Step 5
computer_node_hardware_state = env.nodes["1"].hardware_state
# Retrieve the Access Control List object stored by the environment at the end of the episode
access_control_list = env.acl
# Use the Access Control List object's `acl` attribute to get the dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1
with temp_primaite_session as session:
env = session.env
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)
# Retrieve hardware state of computer_1 node in laydown config
# Agent turned this off in Step 5
computer_node_hardware_state = env.nodes["1"].hardware_state
# Retrieve the Access Control List object stored by the environment at the end of the episode
access_control_list = env.acl
# Use the Access Control List object's `acl` attribute to get the dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1

View File

@@ -16,7 +16,9 @@ def test_legacy_lay_down_config_yaml_conversion():
with open(new_path, "r") as file:
new_dict = yaml.safe_load(file)
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
converted_dict = training_config.convert_legacy_training_config_dict(
legacy_dict
)
for key, value in new_dict.items():
assert converted_dict[key] == value