#917 - Integrated the PrimaiteSession into all tests.
- Ran the full pre-commit hook suite, which surfaced a large number of required fixes
This commit is contained in:
@@ -59,4 +59,4 @@ steps:
|
||||
|
||||
- script: |
|
||||
pytest tests/
|
||||
displayName: 'Run unmarked tests'
|
||||
displayName: 'Run tests'
|
||||
|
||||
@@ -13,6 +13,9 @@ repos:
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [ "--line-length=79" ]
|
||||
additional_dependencies:
|
||||
- jupyter
|
||||
- repo: http://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
@@ -22,4 +25,5 @@ repos:
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
additional_dependencies: [ flake8-docstrings ]
|
||||
additional_dependencies:
|
||||
- flake8-docstrings
|
||||
|
||||
@@ -22,7 +22,7 @@ The environment config file consists of the following attributes:
|
||||
|
||||
|
||||
* **agent_framework** [enum]
|
||||
|
||||
|
||||
This identifies the agent framework used to instantiate the agent algorithm. Select one of the following:
|
||||
|
||||
* NONE - Where a user-developed agent is to be used
|
||||
@@ -30,14 +30,14 @@ The environment config file consists of the following attributes:
|
||||
* RLLIB - Ray RLlib.
|
||||
|
||||
* **agent_identifier**
|
||||
|
||||
|
||||
This identifies the agent to use for the session. Select one of the following:
|
||||
|
||||
* A2C - Advantage Actor Critic
|
||||
* PPO - Proximal Policy Optimization
|
||||
* HARDCODED - A custom-built deterministic agent
|
||||
* RANDOM - A stochastic (random) agent
|
||||
|
||||
|
||||
|
||||
* **action_type** [enum]
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| astunparse | 1.6.3 | BSD License | https://github.com/simonpercivall/astunparse |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall |
|
||||
@@ -103,6 +105,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn |
|
||||
@@ -111,9 +115,13 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| furo | 2023.3.27 | MIT License | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| google-auth-oauthlib | 1.0.0 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
|
||||
| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| grpcio | 1.51.3 | Apache Software License | https://grpc.io |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
@@ -121,6 +129,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| h5py | 3.9.0 | BSD License | https://www.h5py.org/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| idna | 3.4 | BSD License | https://github.com/kjd/idna |
|
||||
@@ -141,6 +151,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 |
|
||||
@@ -151,14 +163,14 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter-events | 0.6.3 | BSD License | http://jupyter.org |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter_client | 8.2.0 | BSD License | https://jupyter.org |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter_core | 5.3.0 | BSD License | https://jupyter.org |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org |
|
||||
@@ -171,10 +183,14 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| keras | 2.12.0 | Apache Software License | https://keras.io/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| kiwisolver | 1.4.4 | BSD License | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py |
|
||||
@@ -183,19 +199,23 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| mpmath | 1.3.0 | BSD License | http://mpmath.org/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic |
|
||||
| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| nbclient | 0.8.0 | BSD License | https://jupyter.org |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
@@ -217,6 +237,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging |
|
||||
@@ -231,11 +253,17 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| plotly | 5.15.0 | MIT License | https://plotly.com/python/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| primaite | 1.2.1 | GFX | UNKNOWN |
|
||||
| primaite | 2.0.0rc1 | GFX | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| primaite | 2.0.0rc1 | GFX | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| primaite | 2.0.0rc1 | GFX | UNKNOWN |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
@@ -295,6 +323,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| rsa | 4.9 | Apache Software License | https://stuvel.eu/rsa |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| scikit-image | 0.20.0 | BSD License | https://scikit-image.org |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| scipy | 1.10.1 | BSD License | https://scipy.org/ |
|
||||
@@ -335,14 +365,26 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorboard | 2.12.3 | Apache Software License | https://github.com/tensorflow/tensorboard |
|
||||
| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorboard-data-server | 0.7.0 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
|
||||
| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com |
|
||||
@@ -377,6 +419,8 @@
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+
|
||||
| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket |
|
||||
|
||||
@@ -31,6 +31,8 @@ dependencies = [
|
||||
"networkx==3.1",
|
||||
"numpy==1.23.5",
|
||||
"platformdirs==3.5.1",
|
||||
"plotly==5.15.0",
|
||||
"polars==0.18.4",
|
||||
"PyYAML==6.0",
|
||||
"ray[rllib]==2.2.0",
|
||||
"stable-baselines3==1.6.2",
|
||||
@@ -69,3 +71,12 @@ tensorflow = [
|
||||
|
||||
[project.scripts]
|
||||
primaite = "primaite.cli:app"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 79
|
||||
force_sort_within_sections = "False"
|
||||
order_by_type = "False"
|
||||
|
||||
[tool.black]
|
||||
line-length = 79
|
||||
|
||||
@@ -1 +1 @@
|
||||
2.0.0rc1
|
||||
2.0.0rc1
|
||||
|
||||
@@ -3,12 +3,10 @@ import logging
|
||||
import logging.config
|
||||
import sys
|
||||
from bisect import bisect
|
||||
from logging import Formatter, LogRecord, StreamHandler
|
||||
from logging import Logger
|
||||
from logging import Formatter, Logger, LogRecord, StreamHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
from typing import Dict, Final
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
@@ -21,7 +19,6 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
def _get_primaite_config():
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
|
||||
config_path = Path(
|
||||
pkg_resources.resource_filename(
|
||||
"primaite", "setup/_package_data/primaite_config.yaml"
|
||||
@@ -37,7 +34,9 @@ def _get_primaite_config():
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
|
||||
primaite_config["log_level"] = log_level_map[
|
||||
primaite_config["logging"]["log_level"]
|
||||
]
|
||||
return primaite_config
|
||||
|
||||
|
||||
@@ -111,9 +110,13 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
|
||||
{
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][
|
||||
"WARNING"
|
||||
],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"]
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][
|
||||
"CRITICAL"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,9 @@ class AccessControlList:
|
||||
|
||||
def __init__(self):
|
||||
"""Init."""
|
||||
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
|
||||
self.acl: Dict[
|
||||
str, AccessControlList
|
||||
] = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
@@ -37,13 +39,17 @@ class AccessControlList:
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == "ANY"
|
||||
)
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
|
||||
or (
|
||||
_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY"
|
||||
)
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
def is_blocked(
|
||||
self, _source_ip_address, _dest_ip_address, _protocol, _port
|
||||
):
|
||||
"""
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
@@ -87,7 +93,9 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
new_rule = ACLRule(
|
||||
_permission, _source_ip, _dest_ip, _protocol, str(_port)
|
||||
)
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
@@ -102,7 +110,9 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
rule = ACLRule(
|
||||
_permission, _source_ip, _dest_ip, _protocol, str(_port)
|
||||
)
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'poppable' since the agent will be trying random things
|
||||
try:
|
||||
@@ -114,7 +124,9 @@ class AccessControlList:
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def get_dictionary_hash(
|
||||
self, _permission, _source_ip, _dest_ip, _protocol, _port
|
||||
):
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
@@ -128,6 +140,8 @@ class AccessControlList:
|
||||
Returns:
|
||||
Hash value based on rule parameters.
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
rule = ACLRule(
|
||||
_permission, _source_ip, _dest_ip, _protocol, str(_port)
|
||||
)
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
@@ -30,7 +30,13 @@ class ACLRule:
|
||||
Returns hash of core parameters.
|
||||
"""
|
||||
return hash(
|
||||
(self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
|
||||
(
|
||||
self.permission,
|
||||
self.source_ip,
|
||||
self.dest_ip,
|
||||
self.protocol,
|
||||
self.port,
|
||||
)
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Union
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.config import lay_down_config
|
||||
from primaite.config import training_config
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
@@ -35,13 +39,15 @@ def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
|
||||
|
||||
class AgentSessionABC(ABC):
|
||||
"""
|
||||
An ABC that manages training and/or evaluation of agents in PrimAITE.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from
|
||||
with all abstract methods implemented.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
@@ -66,9 +72,8 @@ class AgentSessionABC(ABC):
|
||||
self._uuid = str(uuid4())
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
"The session timestamp"
|
||||
self.session_path = _get_session_path(self.session_timestamp)
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def timestamp_str(self) -> str:
|
||||
@@ -78,17 +83,23 @@ class AgentSessionABC(ABC):
|
||||
@property
|
||||
def learning_path(self) -> Path:
|
||||
"""The learning outputs path."""
|
||||
return self.session_path / "learning"
|
||||
path = self.session_path / "learning"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def evaluation_path(self) -> Path:
|
||||
"""The evaluation outputs path."""
|
||||
return self.session_path / "evaluation"
|
||||
path = self.session_path / "evaluation"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def checkpoints_path(self) -> Path:
|
||||
"""The Session checkpoints path."""
|
||||
return self.learning_path / "checkpoints"
|
||||
path = self.learning_path / "checkpoints"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
@@ -118,14 +129,8 @@ class AgentSessionABC(ABC):
|
||||
"uuid": self.uuid,
|
||||
"start_datetime": self.session_timestamp.isoformat(),
|
||||
"end_datetime": None,
|
||||
"learning": {
|
||||
"total_episodes": None,
|
||||
"total_time_steps": None
|
||||
},
|
||||
"evaluation": {
|
||||
"total_episodes": None,
|
||||
"total_time_steps": None
|
||||
},
|
||||
"learning": {"total_episodes": None, "total_time_steps": None},
|
||||
"evaluation": {"total_episodes": None, "total_time_steps": None},
|
||||
"env": {
|
||||
"training_config": self._training_config.to_dict(
|
||||
json_serializable=True
|
||||
@@ -156,11 +161,19 @@ class AgentSessionABC(ABC):
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
metadata_dict["learning"][
|
||||
"total_episodes"
|
||||
] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"][
|
||||
"total_time_steps"
|
||||
] = self._env.total_step_count # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
metadata_dict["evaluation"][
|
||||
"total_episodes"
|
||||
] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"][
|
||||
"total_time_steps"
|
||||
] = self._env.total_step_count # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -187,26 +200,47 @@ class AgentSessionABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_learn:
|
||||
_LOGGER.info("Finished learning")
|
||||
_LOGGER.debug("Writing transactions")
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
self.is_eval = False
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
self._plot_av_reward_per_episode(learning_session=False)
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
@@ -216,6 +250,7 @@ class AgentSessionABC(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
|
||||
"""Load an agent from file."""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
@@ -246,21 +281,56 @@ class AgentSessionABC(ABC):
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
msg = (
|
||||
f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
self._agent.save(self.session_path)
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
"""Closes the agent."""
|
||||
self._env.episode_av_reward_writer.close() # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True):
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
|
||||
if learning_session:
|
||||
title += "(Learning)"
|
||||
path = self.learning_path / csv_file
|
||||
image_path = self.learning_path / image_file
|
||||
else:
|
||||
title += "(Evaluation)"
|
||||
path = self.evaluation_path / csv_file
|
||||
image_path = self.evaluation_path / image_file
|
||||
|
||||
fig = plot_av_reward_per_episode(path, title, subtitle)
|
||||
fig.write_image(image_path)
|
||||
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluating deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from
|
||||
with all abstract methods implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
@@ -270,13 +340,12 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
@@ -284,11 +353,20 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
@@ -296,20 +374,31 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
obs = self._env.reset()
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
@@ -322,15 +411,18 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
|
||||
@@ -11,9 +11,13 @@ from primaite.common.enums import HardCodedAgentView
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
if (
|
||||
self._training_config.hard_coded_agent_view
|
||||
== HardCodedAgentView.BASIC
|
||||
):
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
@@ -22,6 +26,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(self, green_iers, acl, nodes):
|
||||
"""
|
||||
Get blocked green IERs.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
blocked_green_iers = {}
|
||||
|
||||
for green_ier_id, green_ier in green_iers.items():
|
||||
@@ -33,8 +43,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol,
|
||||
port):
|
||||
if acl.is_blocked(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
@@ -42,8 +53,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get matching ACL rules for an IER.
|
||||
"""
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
source_node_id = ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = ier.get_dest_node_id()
|
||||
@@ -51,17 +64,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address,
|
||||
dest_node_address, protocol,
|
||||
port)
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
Warning: Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked)
|
||||
"""
|
||||
|
||||
.. warning::
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
(No ALLOW rule, therefore blocked).
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
blocked_rules = {}
|
||||
@@ -74,8 +92,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get all allowing ACL rules for an IER.
|
||||
"""
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
allowed_rules = {}
|
||||
@@ -85,9 +105,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol,
|
||||
port, acl,
|
||||
nodes, services_list):
|
||||
def get_matching_acl_rules(
|
||||
self,
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get matching ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
if source_node_id != "ANY":
|
||||
source_node_address = nodes[str(source_node_id)].ip_address
|
||||
else:
|
||||
@@ -100,21 +133,39 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[
|
||||
protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
protocol - 1
|
||||
] # -1 as dont have to account for ANY in list of services
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address,
|
||||
dest_node_address, protocol,
|
||||
port)
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
return matching_rules
|
||||
|
||||
def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol,
|
||||
port, acl,
|
||||
nodes, services_list):
|
||||
matching_rules = self.get_matching_acl_rules(source_node_id,
|
||||
dest_node_id,
|
||||
protocol, port, acl,
|
||||
nodes,
|
||||
services_list)
|
||||
def get_allow_acl_rules(
|
||||
self,
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get the ALLOW ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
@@ -123,14 +174,31 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port,
|
||||
acl,
|
||||
nodes, services_list):
|
||||
matching_rules = self.get_matching_acl_rules(source_node_id,
|
||||
dest_node_id,
|
||||
protocol, port, acl,
|
||||
nodes,
|
||||
services_list)
|
||||
def get_deny_acl_rules(
|
||||
self,
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get the DENY ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
@@ -141,7 +209,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def _calculate_action_full_view(self, obs):
|
||||
"""
|
||||
Given an observation and the environment calculate a good acl-based action for the blue agent to take
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
||||
|
||||
@@ -167,8 +235,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
|
||||
an overwhelmed state, so we don't do this.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
#obs = convert_to_old_obs(obs)
|
||||
# obs = convert_to_old_obs(obs)
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, _, _, *s = r_obs
|
||||
|
||||
@@ -184,7 +254,6 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
for service_num, service_states in enumerate(s):
|
||||
for x, service_state in enumerate(service_states):
|
||||
if service_state == "COMPROMISED":
|
||||
|
||||
action_source_id = x + 1 # +1 as 0 is any
|
||||
action_destination_id = "ANY"
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
@@ -215,19 +284,23 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = rule.get_source_ip()
|
||||
action_source_id = int(
|
||||
get_node_of_ip(action_source_ip, self._env.nodes))
|
||||
get_node_of_ip(action_source_ip, self._env.nodes)
|
||||
)
|
||||
action_destination_ip = rule.get_dest_ip()
|
||||
action_destination_id = int(
|
||||
get_node_of_ip(action_destination_ip,
|
||||
self._env.nodes))
|
||||
get_node_of_ip(
|
||||
action_destination_ip, self._env.nodes
|
||||
)
|
||||
)
|
||||
action_protocol_name = rule.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(
|
||||
action_protocol_name) + 1
|
||||
self._env.services_list.index(action_protocol_name)
|
||||
+ 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = rule.get_port()
|
||||
action_port = self._env.ports_list.index(
|
||||
action_port_name) + 1 # convert port name e.g. '80' to index
|
||||
action_port = (
|
||||
self._env.ports_list.index(action_port_name) + 1
|
||||
) # convert port name e.g. '80' to index
|
||||
|
||||
found_action = True
|
||||
break
|
||||
@@ -258,21 +331,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
if not found_action:
|
||||
# Which Green IERS are blocked
|
||||
blocked_green_iers = self.get_blocked_green_iers(
|
||||
self._env.green_iers, self._env.acl,
|
||||
self._env.nodes)
|
||||
self._env.green_iers, self._env.acl, self._env.nodes
|
||||
)
|
||||
for ier_key, ier in blocked_green_iers.items():
|
||||
|
||||
# Which ALLOW rules are allowing this IER (none)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(ier,
|
||||
self._env.acl,
|
||||
self._env.nodes)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(
|
||||
ier, self._env.acl, self._env.nodes
|
||||
)
|
||||
|
||||
# If there are no blocking rules, it may be being blocked by default
|
||||
# If there is already an allow rule
|
||||
node_id_to_check = int(ier.get_source_node_id())
|
||||
service_name_to_check = ier.get_protocol()
|
||||
service_id_to_check = self._env.services_list.index(
|
||||
service_name_to_check)
|
||||
service_name_to_check
|
||||
)
|
||||
|
||||
# Service state of the source node in the IER
|
||||
service_state = s[service_id_to_check][node_id_to_check - 1]
|
||||
@@ -283,11 +356,13 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
action_source_id = int(ier.get_source_node_id())
|
||||
action_destination_id = int(ier.get_dest_node_id())
|
||||
action_protocol_name = ier.get_protocol()
|
||||
action_protocol = self._env.services_list.index(
|
||||
action_protocol_name) + 1 # convert name e.g. 'TCP' to index
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = ier.get_port()
|
||||
action_port = self._env.ports_list.index(
|
||||
action_port_name) + 1 # convert port name e.g. '80' to index
|
||||
action_port = (
|
||||
self._env.ports_list.index(action_port_name) + 1
|
||||
) # convert port name e.g. '80' to index
|
||||
|
||||
found_action = True
|
||||
break
|
||||
@@ -311,19 +386,25 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs):
|
||||
"""Given an observation calculate a good acl-based action for the blue agent to take
|
||||
"""Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Uses ONLY information from the current observation with NO knowledge of previous actions taken and
|
||||
NO reward feedback.
|
||||
Uses ONLY information from the current observation with NO knowledge
|
||||
of previous actions taken and NO reward feedback.
|
||||
|
||||
We rely on randomness to select the precise action, as we want to block all traffic originating from
|
||||
a compromised node, without being able to tell:
|
||||
We rely on randomness to select the precise action, as we want to
|
||||
block all traffic originating from a compromised node, without being
|
||||
able to tell:
|
||||
1. Which ACL rules already exist
|
||||
1. Which actions the agent has already tried.
|
||||
2. Which actions the agent has already tried.
|
||||
|
||||
There is a high probability that the correct rule will not be deleted before the state becomes overwhelmed.
|
||||
There is a high probability that the correct rule will not be deleted
|
||||
before the state becomes overwhelmed.
|
||||
|
||||
Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted.
|
||||
Currently, a deny rule does not overwrite an allow rule. The allow
|
||||
rules must be deleted.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
@@ -333,27 +414,35 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
s = [*s]
|
||||
|
||||
number_of_nodes = len(
|
||||
[i for i in o if i != "NONE"]) # number of nodes (not links)
|
||||
[i for i in o if i != "NONE"]
|
||||
) # number of nodes (not links)
|
||||
for service_num, service_states in enumerate(s):
|
||||
comprimised_states = [n for n, i in enumerate(service_states) if
|
||||
i == "COMPROMISED"]
|
||||
comprimised_states = [
|
||||
n for n, i in enumerate(service_states) if i == "COMPROMISED"
|
||||
]
|
||||
if len(comprimised_states) == 0:
|
||||
# No states are COMPROMISED, try the next service
|
||||
continue
|
||||
|
||||
compromised_node = np.random.choice(
|
||||
comprimised_states) + 1 # +1 as 0 would be any
|
||||
compromised_node = (
|
||||
np.random.choice(comprimised_states) + 1
|
||||
) # +1 as 0 would be any
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = compromised_node
|
||||
# Randomly select a destination ID to block
|
||||
action_destination_ip = np.random.choice(
|
||||
list(range(1, number_of_nodes + 1)) + ["ANY"])
|
||||
action_destination_ip = int(
|
||||
action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
|
||||
list(range(1, number_of_nodes + 1)) + ["ANY"]
|
||||
)
|
||||
action_destination_ip = (
|
||||
int(action_destination_ip)
|
||||
if action_destination_ip != "ANY"
|
||||
else action_destination_ip
|
||||
)
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
# Randomly select a port to block
|
||||
# Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port
|
||||
# Bad assumption that number of protocols equals number of ports
|
||||
# AND no rules exist with an ANY port
|
||||
action_port = np.random.choice(list(range(1, len(s) + 1)))
|
||||
|
||||
action = [
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.agents.utils import (
|
||||
transform_action_node_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic Node agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
"""Given an observation calculate a good node-based action for the blue agent to take"""
|
||||
"""
|
||||
Calculate a good node-based action for the blue agent to take.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, o, os, *s = r_obs
|
||||
@@ -18,7 +23,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
# Check in order of most important states (order doesn't currently matter, but it probably should)
|
||||
# Check in order of most important states (order doesn't currently
|
||||
# matter, but it probably should)
|
||||
# First see if any OS states are compromised
|
||||
for x, os_state in enumerate(os):
|
||||
if os_state == "COMPROMISED":
|
||||
@@ -26,8 +32,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
action_node_property = "OS"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = 0 # does nothing isn't relevant for os
|
||||
action = [action_node_id, action_node_property,
|
||||
property_action, action_service_index]
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
@@ -44,8 +54,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
property_action = "PATCHING"
|
||||
action_service_index = service_num
|
||||
|
||||
action = [action_node_id, action_node_property,
|
||||
property_action, action_service_index]
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
@@ -63,8 +77,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
property_action = "PATCHING"
|
||||
action_service_index = service_num
|
||||
|
||||
action = [action_node_id, action_node_property,
|
||||
property_action, action_service_index]
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
@@ -75,10 +93,18 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
if os_state == "OFF":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OPERATING"
|
||||
property_action = "ON" # Why reset it when we can just turn it on
|
||||
action_service_index = 0 # does nothing isn't relevant for operating state
|
||||
action = [action_node_id, action_node_property,
|
||||
property_action, action_service_index]
|
||||
property_action = (
|
||||
"ON" # Why reset it when we can just turn it on
|
||||
)
|
||||
action_service_index = (
|
||||
0 # does nothing isn't relevant for operating state
|
||||
)
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action, action_dict)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
@@ -89,8 +115,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
action_node_property = "NONE"
|
||||
property_action = "ON"
|
||||
action_service_index = 0
|
||||
action = [action_node_id, action_node_property, property_action,
|
||||
action_service_index]
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
|
||||
|
||||
@@ -1,28 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
import tensorflow as tf
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, \
|
||||
DeepLearningFramework
|
||||
from primaite.common.enums import (
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
)
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _env_creator(env_config):
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
session_path=env_config["session_path"],
|
||||
timestamp_str=env_config["timestamp_str"]
|
||||
timestamp_str=env_config["timestamp_str"],
|
||||
)
|
||||
|
||||
|
||||
@@ -37,16 +44,15 @@ def _custom_log_creator(session_path: Path):
|
||||
|
||||
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = (f"Expected RLLIB agent_framework, "
|
||||
f"got {self._training_config.agent_framework}")
|
||||
msg = (
|
||||
f"Expected RLLIB agent_framework, "
|
||||
f"got {self._training_config.agent_framework}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
@@ -54,8 +60,10 @@ class RLlibAgent(AgentSessionABC):
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = ("Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}")
|
||||
msg = (
|
||||
"Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: PPOConfig
|
||||
@@ -86,8 +94,12 @@ class RLlibAgent(AgentSessionABC):
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
metadata_dict["total_episodes"] = self._current_result[
|
||||
"episodes_total"
|
||||
]
|
||||
metadata_dict["total_time_steps"] = self._current_result[
|
||||
"timesteps_total"
|
||||
]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -106,43 +118,48 @@ class RLlibAgent(AgentSessionABC):
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
timestamp_str=self.timestamp_str,
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._agent_config.framework(
|
||||
framework="tf"
|
||||
)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps
|
||||
horizon=self._training_config.num_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(
|
||||
logger_creator=_custom_log_creator(self.session_path)
|
||||
)
|
||||
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (
|
||||
(episode_count % checkpoint_n == 0)
|
||||
or (episode_count == self._training_config.num_episodes)
|
||||
if (episode_count % checkpoint_n == 0) or (
|
||||
episode_count == self._training_config.num_episodes
|
||||
):
|
||||
self._agent.save(self.checkpoints_path)
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
# Temporarily override train_batch_size and horizon
|
||||
if time_steps:
|
||||
self._agent_config.train_batch_size = time_steps
|
||||
@@ -150,37 +167,53 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps..."
|
||||
)
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH:
|
||||
if (
|
||||
self._training_config.deep_learning_framework
|
||||
!= DeepLearningFramework.TORCH
|
||||
):
|
||||
policy = self._agent.get_policy()
|
||||
tf.compat.v1.summary.FileWriter(
|
||||
self.session_path / "ray_results",
|
||||
policy.get_session().graph
|
||||
self.session_path / "ray_results", policy.get_session().graph
|
||||
)
|
||||
super().learn()
|
||||
self._agent.stop()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
def load(cls, path: Union[str, Path]) -> RLlibAgent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,27 +1,30 @@
|
||||
from typing import Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO, A2C
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentIdentifier, AgentFramework
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = (f"Expected SB3 agent_framework, "
|
||||
f"got {self._training_config.agent_framework}")
|
||||
msg = (
|
||||
f"Expected SB3 agent_framework, "
|
||||
f"got {self._training_config.agent_framework}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
@@ -29,8 +32,10 @@ class SB3Agent(AgentSessionABC):
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = ("Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}")
|
||||
msg = (
|
||||
"Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -52,25 +57,26 @@ class SB3Agent(AgentSessionABC):
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path
|
||||
tensorboard_log=self._tensorboard_log_path,
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (
|
||||
(episode_count % checkpoint_n == 0)
|
||||
or (episode_count == self._training_config.num_episodes)
|
||||
if (episode_count % checkpoint_n == 0) or (
|
||||
episode_count == self._training_config.num_episodes
|
||||
):
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
checkpoint_path = (
|
||||
self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
)
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
@@ -78,33 +84,54 @@ class SB3Agent(AgentSessionABC):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps..."
|
||||
)
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
|
||||
self._env.close()
|
||||
self.close()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
deterministic: bool = True,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
deterministic: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
@@ -116,27 +143,31 @@ class SB3Agent(AgentSessionABC):
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(f"Beginning {deterministic_str} evaluation for "
|
||||
f"{episodes} episodes @ {time_steps} time steps...")
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for "
|
||||
f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(
|
||||
obs,
|
||||
deterministic=deterministic
|
||||
obs, deterministic=deterministic
|
||||
)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
_LOGGER.info(f"Finished evaluation")
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(self):
|
||||
def load(cls, path: Union[str, Path]) -> SB3Agent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -4,9 +4,9 @@ from primaite.common.enums import (
|
||||
HardwareState,
|
||||
LinkStatus,
|
||||
NodeHardwareAction,
|
||||
NodePOLType,
|
||||
NodeSoftwareAction,
|
||||
SoftwareState,
|
||||
NodePOLType
|
||||
)
|
||||
|
||||
|
||||
@@ -16,14 +16,17 @@ def transform_action_node_readable(action):
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
|
||||
2
|
||||
] <= 1:
|
||||
elif (
|
||||
action_node_property == "OS" or action_node_property == "SERVICE"
|
||||
) and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
@@ -38,6 +41,9 @@ def transform_action_acl_readable(action):
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
@@ -62,6 +68,9 @@ def is_valid_node_action(action):
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
@@ -77,7 +86,10 @@ def is_valid_node_action(action):
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
|
||||
if node_property != "OPERATING" and node_action not in [
|
||||
"NONE",
|
||||
"PATCHING",
|
||||
]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
@@ -92,6 +104,9 @@ def is_valid_acl_action(action):
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
@@ -118,7 +133,12 @@ def is_valid_acl_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action."""
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
@@ -136,13 +156,15 @@ def is_valid_acl_action_extra(action):
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs):
|
||||
"""Transform list of transactions to readable list of each observation property
|
||||
"""
|
||||
Transform list of transactions to readable list of each observation property.
|
||||
|
||||
example:
|
||||
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
ids = [i for i in obs[:, 0]]
|
||||
operating_states = [HardwareState(i).name for i in obs[:, 1]]
|
||||
@@ -151,7 +173,9 @@ def transform_change_obs_readable(obs):
|
||||
|
||||
for service in range(3, obs.shape[1]):
|
||||
# Links bit/s don't have a service state
|
||||
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
|
||||
service_states = [
|
||||
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
|
||||
]
|
||||
new_obs.append(service_states)
|
||||
|
||||
return new_obs
|
||||
@@ -159,10 +183,13 @@ def transform_change_obs_readable(obs):
|
||||
|
||||
def transform_obs_readable(obs):
|
||||
"""
|
||||
example:
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
"""
|
||||
Transform observation to readable format.
|
||||
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
changed_obs = transform_change_obs_readable(obs)
|
||||
new_obs = list(zip(*changed_obs))
|
||||
# Convert list of tuples to list of lists
|
||||
@@ -172,7 +199,12 @@ def transform_obs_readable(obs):
|
||||
|
||||
|
||||
def convert_to_new_obs(obs, num_nodes=10):
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space"""
|
||||
"""
|
||||
Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
|
||||
new_obs = obs[:num_nodes, 1:].flatten()
|
||||
return new_obs
|
||||
@@ -180,7 +212,9 @@ def convert_to_new_obs(obs, num_nodes=10):
|
||||
|
||||
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Convert to old observation, links filled with 0's as no information is included in new observation space
|
||||
Convert to old observation.
|
||||
|
||||
Links filled with 0's as no information is included in new observation space.
|
||||
|
||||
example:
|
||||
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
|
||||
@@ -190,13 +224,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
[ 3, 1, 1, 1],
|
||||
...
|
||||
[20, 0, 0, 0]])
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
|
||||
# Convert back to more readable, original format
|
||||
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
|
||||
|
||||
# Add empty links back and add node ID back
|
||||
s = np.zeros([reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], dtype=np.int64)
|
||||
s = np.zeros(
|
||||
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
|
||||
dtype=np.int64,
|
||||
)
|
||||
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
|
||||
s[:num_nodes, 1:] = reshaped_nodes # put values back in
|
||||
new_obs = s
|
||||
@@ -209,14 +247,19 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""Return string describing change between two observations
|
||||
def describe_obs_change(
|
||||
obs1, obs2, num_nodes=10, num_links=10, num_services=1
|
||||
):
|
||||
"""
|
||||
Return string describing change between two observations.
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
|
||||
output = 'ID 1: SERVICE 2 set to GOOD'
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
|
||||
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
|
||||
@@ -236,20 +279,27 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change, is_link):
|
||||
""" "
|
||||
Helper funcion to describe what has changed
|
||||
"""
|
||||
Helper function to describe what has changed.
|
||||
|
||||
example:
|
||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||
|
||||
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
# Indexes where a change has occurred, not including 0th index
|
||||
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
|
||||
index_changed = [
|
||||
i for i in range(1, len(obs_change)) if obs_change[i] != -1
|
||||
]
|
||||
# Node pol types, Indexes >= 3 are service nodes
|
||||
NodePOLTypes = [
|
||||
NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed
|
||||
NodePOLType(i).name
|
||||
if i < 3
|
||||
else NodePOLType(3).name + " " + str(i - 3)
|
||||
for i in index_changed
|
||||
]
|
||||
# Account for hardware states, software states and links
|
||||
states = [
|
||||
@@ -263,8 +313,8 @@ def _describe_obs_change_helper(obs_change, is_link):
|
||||
|
||||
if not is_link:
|
||||
desc = f"ID {obs_change[0]}:"
|
||||
for NodePOLType, state in list(zip(NodePOLTypes, states)):
|
||||
desc = desc + " " + NodePOLType + " changed to " + state + "."
|
||||
for node_pol_type, state in list(zip(NodePOLTypes, states)):
|
||||
desc = desc + " " + node_pol_type + " changed to " + state + "."
|
||||
else:
|
||||
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
|
||||
|
||||
@@ -273,12 +323,14 @@ def _describe_obs_change_helper(obs_change, is_link):
|
||||
|
||||
def transform_action_node_enum(action):
|
||||
"""
|
||||
Convert a node action from readable string format, to enumerated format
|
||||
Convert a node action from readable string format, to enumerated format.
|
||||
|
||||
example:
|
||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||
"""
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_node_id = action[0]
|
||||
action_node_property = NodePOLType[action[1]].value
|
||||
|
||||
@@ -291,24 +343,33 @@ def transform_action_node_enum(action):
|
||||
|
||||
action_service_index = action[3]
|
||||
|
||||
new_action = [action_node_id, action_node_property, property_action, action_service_index]
|
||||
new_action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format
|
||||
Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
"""
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
elif (
|
||||
action_node_property == "OS" or action_node_property == "SERVICE"
|
||||
) and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
@@ -319,9 +380,11 @@ def transform_action_node_readable(action):
|
||||
|
||||
def node_action_description(action):
|
||||
"""
|
||||
Generate string describing a node-based action
|
||||
"""
|
||||
Generate string describing a node-based action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
if isinstance(action[1], (int, np.int64)):
|
||||
# transform action to readable format
|
||||
action = transform_action_node_readable(action)
|
||||
@@ -334,7 +397,9 @@ def node_action_description(action):
|
||||
if property_action == "NONE":
|
||||
return ""
|
||||
if node_property == "OPERATING" or node_property == "OS":
|
||||
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
description = (
|
||||
f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
)
|
||||
elif node_property == "SERVICE":
|
||||
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
|
||||
else:
|
||||
@@ -343,34 +408,13 @@ def node_action_description(action):
|
||||
return description
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
"""
|
||||
Transform an ACL action to a more readable format
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
"""
|
||||
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == 0:
|
||||
new_action[n + 2] = "ANY"
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_enum(action):
|
||||
"""
|
||||
Convert a acl action from readable string format, to enumerated format
|
||||
"""
|
||||
Convert acl action from readable str format, to enumerated format.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
|
||||
action_permissions = {"DENY": 0, "ALLOW": 1}
|
||||
|
||||
@@ -388,8 +432,12 @@ def transform_action_acl_enum(action):
|
||||
|
||||
|
||||
def acl_action_description(action):
|
||||
"""generate string describing a acl-based action"""
|
||||
"""
|
||||
Generate string describing an acl-based action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
if isinstance(action[0], (int, np.int64)):
|
||||
# transform action to readable format
|
||||
action = transform_action_acl_readable(action)
|
||||
@@ -406,11 +454,13 @@ def acl_action_description(action):
|
||||
|
||||
def get_node_of_ip(ip, node_dict):
|
||||
"""
|
||||
Get the node ID of an IP address
|
||||
Get the node ID of an IP address.
|
||||
|
||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be obtained from env.nodes)
|
||||
"""
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
for node_key, node_value in node_dict.items():
|
||||
node_ip = node_value.ip_address
|
||||
if node_ip == ip:
|
||||
@@ -418,13 +468,16 @@ def get_node_of_ip(ip, node_dict):
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
"""Is the node action an actual valid action
|
||||
"""Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
@@ -438,7 +491,10 @@ def is_valid_node_action(action):
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
|
||||
if node_property != "OPERATING" and node_action not in [
|
||||
"NONE",
|
||||
"PATCHING",
|
||||
]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
@@ -446,13 +502,16 @@ def is_valid_node_action(action):
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""
|
||||
Is the ACL action an actual valid action
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
@@ -463,7 +522,11 @@ def is_valid_acl_action(action):
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
@@ -475,7 +538,12 @@ def is_valid_acl_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action"""
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
@@ -494,33 +562,17 @@ def is_valid_acl_action_extra(action):
|
||||
|
||||
|
||||
def get_new_action(old_action, action_dict):
|
||||
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0]
|
||||
|
||||
old_action can be either node or acl action type
|
||||
"""
|
||||
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
|
||||
Old_action can be either node or acl action type
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
for key, val in action_dict.items():
|
||||
if list(val) == list(old_action):
|
||||
return key
|
||||
# Not all possible actions are included in dict, only valid action are
|
||||
# if action is not in the dict, it's an invalid action so return 0
|
||||
return 0
|
||||
|
||||
|
||||
def get_action_description(action, action_dict):
|
||||
"""
|
||||
Get a string describing/explaining what an action is doing in words
|
||||
"""
|
||||
|
||||
action_array = action_dict[action]
|
||||
if len(action_array) == 4:
|
||||
# node actions have length 4
|
||||
action_description = node_action_description(action_array)
|
||||
elif len(action_array) == 6:
|
||||
# acl actions have length 6
|
||||
action_description = acl_action_description(action_array)
|
||||
else:
|
||||
# Should never happen
|
||||
action_description = "Unrecognised action"
|
||||
|
||||
return action_description
|
||||
|
||||
@@ -13,6 +13,8 @@ import yaml
|
||||
from platformdirs import PlatformDirs
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite.data_viz import PlotlyTemplate
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@@ -54,7 +56,9 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
print(re.sub(r"\n*", "", line))
|
||||
|
||||
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
_LogLevel = Enum(
|
||||
"LogLevel", {k: k for k in logging._levelToName.values()}
|
||||
) # noqa
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -76,11 +80,12 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if level:
|
||||
primaite_config["log_level"] = level.value
|
||||
primaite_config["logging"]["log_level"] = level.value
|
||||
with open(user_config_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
else:
|
||||
level = primaite_config["log_level"]
|
||||
level = primaite_config["logging"]["log_level"]
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
|
||||
|
||||
@@ -170,16 +175,50 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml.
|
||||
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
"""
|
||||
from primaite.main import run
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
if not ldc:
|
||||
ldc = data_manipulation_config_path()
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(training_config_path=tc, lay_down_config_path=ldc)
|
||||
|
||||
|
||||
@app.command()
def plotly_template(
    template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None
):
    """
    View or set the plotly template for Session plots.

    To View, simply call: primaite plotly-template

    To set, call: primaite plotly-template <desired template>

    For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
    """
    # NOTE: the docstring above is kept verbatim because Typer surfaces it
    # as the command's help text.
    config_dir = PlatformDirs(appname="primaite").user_config_path
    config_dir.mkdir(exist_ok=True, parents=True)
    config_file = config_dir / "primaite_config.yaml"

    # Nothing to view or update when the user config has not been created.
    if not config_file.exists():
        return

    with open(config_file, "r") as stream:
        primaite_config = yaml.safe_load(stream)

    # Same nested dict object as inside primaite_config, so assigning to it
    # below also updates what gets dumped back to disk.
    plots_section = primaite_config["session"]["outputs"]["plots"]

    if not template:
        # View mode: report the template currently stored in the config.
        template = plots_section["template"]
        print(f"PrimAITE plotly template: {template}")
        return

    # Set mode: store the chosen template name and persist the config.
    plots_section["template"] = template.value
    with open(config_file, "w") as stream:
        yaml.dump(primaite_config, stream)
    print(f"PrimAITE plotly template: {template.value}")
|
||||
|
||||
@@ -83,6 +83,7 @@ class Protocol(Enum):
|
||||
|
||||
class SessionType(Enum):
|
||||
"""The type of PrimAITE Session to be run."""
|
||||
|
||||
TRAIN = 1
|
||||
"Train an agent"
|
||||
EVAL = 2
|
||||
@@ -93,6 +94,7 @@ class SessionType(Enum):
|
||||
|
||||
class VerboseLevel(IntEnum):
|
||||
"""PrimAITE Session Output verbose level."""
|
||||
|
||||
NO_OUTPUT = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
@@ -100,6 +102,7 @@ class VerboseLevel(IntEnum):
|
||||
|
||||
class AgentFramework(Enum):
|
||||
"""The agent algorithm framework/package."""
|
||||
|
||||
CUSTOM = 0
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
@@ -110,6 +113,7 @@ class AgentFramework(Enum):
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
"""The deep learning framework."""
|
||||
|
||||
TF = "tf"
|
||||
"Tensorflow"
|
||||
TF2 = "tf2"
|
||||
@@ -120,6 +124,7 @@ class DeepLearningFramework(Enum):
|
||||
|
||||
class AgentIdentifier(Enum):
|
||||
"""The Red Agent algo/class."""
|
||||
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
@@ -136,6 +141,7 @@ class AgentIdentifier(Enum):
|
||||
|
||||
class HardCodedAgentView(Enum):
|
||||
"""The view the deterministic hard-coded agent has of the environment."""
|
||||
|
||||
BASIC = 1
|
||||
"The current observation space only"
|
||||
FULL = 2
|
||||
@@ -144,6 +150,7 @@ class HardCodedAgentView(Enum):
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
@@ -151,6 +158,7 @@ class ActionType(Enum):
|
||||
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
BOX = 0
|
||||
MULTIDISCRETE = 1
|
||||
|
||||
@@ -193,6 +201,7 @@ class LinkStatus(Enum):
|
||||
|
||||
class OutputVerboseLevel(IntEnum):
|
||||
"""The Agent output verbosity level."""
|
||||
|
||||
NONE = 0
|
||||
"No Output"
|
||||
INFO = 1
|
||||
|
||||
@@ -35,10 +35,10 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ANY
|
||||
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 1000
|
||||
num_episodes: 10
|
||||
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
@@ -47,14 +47,14 @@ num_steps: 256
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Final, Union, Dict, Any
|
||||
from typing import Any, Dict, Final, Union
|
||||
|
||||
import networkx
|
||||
import yaml
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[
|
||||
Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = (
|
||||
USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
)
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config_dict(
|
||||
legacy_config_dict: Dict[str, Any]
|
||||
legacy_config_dict: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy lay down config dict to the new format.
|
||||
@@ -25,10 +25,7 @@ def convert_legacy_lay_down_config_dict(
|
||||
return legacy_config_dict
|
||||
|
||||
|
||||
def load(
|
||||
file_path: Union[str, Path],
|
||||
legacy_file: bool = False
|
||||
) -> Dict:
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
"""
|
||||
Read in a lay down config yaml file.
|
||||
|
||||
|
||||
@@ -7,15 +7,22 @@ from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import DeepLearningFramework, HardCodedAgentView
|
||||
from primaite.common.enums import ActionType, AgentIdentifier, \
|
||||
AgentFramework, SessionType, OutputVerboseLevel
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[
|
||||
Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
_EXAMPLE_TRAINING: Final[Path] = (
|
||||
USERS_CONFIG_DIR / "example_config" / "training"
|
||||
)
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
@@ -36,6 +43,7 @@ def main_training_config_path() -> Path:
|
||||
@dataclass()
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
|
||||
agent_framework: AgentFramework = AgentFramework.SB3
|
||||
"The AgentFramework"
|
||||
|
||||
@@ -171,12 +179,16 @@ class TrainingConfig:
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system"
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
config_dict: Dict[str, Union[str, int, bool]]
|
||||
cls, config_dict: Dict[str, Union[str, int, bool]]
|
||||
) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
:param config_dict: The training config dict.
|
||||
:return: The instance of TrainingConfig.
|
||||
"""
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
@@ -187,9 +199,9 @@ class TrainingConfig:
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
}
|
||||
|
||||
for field, enum_class in field_enum_map.items():
|
||||
if field in config_dict:
|
||||
config_dict[field] = enum_class[config_dict[field]]
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
@@ -213,23 +225,21 @@ class TrainingConfig:
|
||||
return data
|
||||
|
||||
def __str__(self) -> str:
|
||||
tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, "
|
||||
tc = f"{self.agent_framework}, "
|
||||
if self.agent_framework is AgentFramework.RLLIB:
|
||||
tc += f"deep_learning_framework=" \
|
||||
f"{self.deep_learning_framework.name}, "
|
||||
tc += f"agent_identifier={self.agent_identifier.name}, "
|
||||
tc += f"{self.deep_learning_framework}, "
|
||||
tc += f"{self.agent_identifier}, "
|
||||
if self.agent_identifier is AgentIdentifier.HARDCODED:
|
||||
tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, "
|
||||
tc += f"action_type={self.action_type.name}, "
|
||||
tc += f"{self.hard_coded_agent_view}, "
|
||||
tc += f"{self.action_type}, "
|
||||
tc += f"observation_space={self.observation_space}, "
|
||||
tc += f"num_episodes={self.num_episodes}, "
|
||||
tc += f"num_steps={self.num_steps})"
|
||||
tc += f"{self.num_episodes} episodes @ "
|
||||
tc += f"{self.num_steps} steps"
|
||||
return tc
|
||||
|
||||
|
||||
def load(
|
||||
file_path: Union[str, Path],
|
||||
legacy_file: bool = False
|
||||
file_path: Union[str, Path], legacy_file: bool = False
|
||||
) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
@@ -273,12 +283,12 @@ def load(
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_steps: int = 256,
|
||||
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_steps: int = 256,
|
||||
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
@@ -301,8 +311,12 @@ def convert_legacy_training_config_dict(
|
||||
"agent_identifier": agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_steps": num_steps,
|
||||
"output_verbose_level": output_verbose_level
|
||||
"output_verbose_level": output_verbose_level.name,
|
||||
}
|
||||
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
|
||||
legacy_config_dict["sessionType"] = session_type_map[
|
||||
legacy_config_dict["sessionType"]
|
||||
]
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
|
||||
13
src/primaite/data_viz/__init__.py
Normal file
13
src/primaite/data_viz/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PlotlyTemplate(Enum):
    """The built-in plotly templates."""

    # Each member's value is a plotly template name as a lowercase string;
    # the member name is the upper-case form of that string.
    PLOTLY = "plotly"
    PLOTLY_WHITE = "plotly_white"
    PLOTLY_DARK = "plotly_dark"
    GGPLOT2 = "ggplot2"
    SEABORN = "seaborn"
    SIMPLE_WHITE = "simple_white"
    # NOTE(review): the string "none" (not Python None) - presumably plotly's
    # "no template" option; confirm against the plotly templates docs.
    NONE = "none"
|
||||
73
src/primaite/data_viz/session_plots.py
Normal file
73
src/primaite/data_viz/session_plots.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from primaite import _PLATFORM_DIRS
|
||||
|
||||
|
||||
def _get_plotly_config() -> Dict:
    """
    Read the plot settings from the user's primaite_config.yaml.

    :return: The mapping stored under ``session -> outputs -> plots`` in
        the config file.
    """
    config_file = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
    with open(config_file, "r") as stream:
        settings = yaml.safe_load(stream)
    plots_settings = settings["session"]["outputs"]["plots"]
    return plots_settings
|
||||
|
||||
|
||||
def plot_av_reward_per_episode(
    av_reward_per_episode_csv: Union[str, Path],
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
) -> Figure:
    """
    Plot the average reward per episode from a csv session output.

    The csv is expected to contain an ``Episode`` column and an
    ``Average Reward`` column, which are used as the x and y values.
    Figure size, template and range-slider visibility are read from the
    plots section of primaite_config.yaml via ``_get_plotly_config``.

    :param av_reward_per_episode_csv: The average reward per episode csv
        file path.
    :param title: The plot title. This is optional.
    :param subtitle: The plot subtitle. This is optional. When given
        together with a title it is placed on a second line inside a
        ``<sup>`` tag; when given alone it is used as the title.
    :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
    """
    df = pl.read_csv(av_reward_per_episode_csv)

    if title and subtitle:
        # The original string closed a </sup> tag that was never opened;
        # the opening <sup> is added so the markup is balanced.
        title = f"{title} <br><sup>{subtitle}</sup>"
    elif subtitle:
        title = subtitle

    config = _get_plotly_config()
    layout = go.Layout(
        autosize=config["size"]["auto_size"],
        width=config["size"]["width"],
        height=config["size"]["height"],
    )
    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=config["template"])
    fig.add_trace(
        go.Scatter(
            x=df["Episode"],
            y=df["Average Reward"],
            mode="lines",
            name="Mean Reward per Episode",
        )
    )

    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
            "rangeslider": {"visible": config["range_slider"]},
        },
        yaxis={"title": "Average Reward"},
        title=title,
        showlegend=False,
    )

    return fig
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
@@ -77,7 +77,9 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
self.current_observation = np.zeros(
|
||||
observation_shape, dtype=self._DATA_TYPE
|
||||
)
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
@@ -92,7 +94,9 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
self.current_observation[item_index][0] = int(node.node_id)
|
||||
self.current_observation[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.current_observation[item_index][2] = node.software_state.value
|
||||
self.current_observation[item_index][
|
||||
2
|
||||
] = node.software_state.value
|
||||
self.current_observation[item_index][
|
||||
3
|
||||
] = node.file_system_state_observed.value
|
||||
@@ -199,9 +203,16 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
service_states[i] = node.get_service_state(
|
||||
service
|
||||
).value
|
||||
obs.extend(
|
||||
[hardware_state, software_state, file_system_state, *service_states]
|
||||
[
|
||||
hardware_state,
|
||||
software_state,
|
||||
file_system_state,
|
||||
*service_states,
|
||||
]
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
@@ -259,7 +270,9 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = (
|
||||
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||
[self._quantisation_levels]
|
||||
* self.env.num_links
|
||||
* self._entries_per_link
|
||||
)
|
||||
|
||||
# 2. Create Observation space
|
||||
@@ -279,7 +292,9 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
if self._combine_service_traffic:
|
||||
loads = [link.get_current_load()]
|
||||
else:
|
||||
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||
loads = [
|
||||
protocol.get_load() for protocol in link.protocol_list
|
||||
]
|
||||
|
||||
for load in loads:
|
||||
if load <= 0:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""Main environment module containing the PRIMary AI Training Environment (Primaite) class."""
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union, Final
|
||||
from typing import Dict, Final, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -12,8 +12,10 @@ from matplotlib import pyplot as plt
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
from primaite.agents.utils import (
|
||||
is_valid_acl_action_extra,
|
||||
is_valid_node_action,
|
||||
)
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -24,7 +26,8 @@ from primaite.common.enums import (
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState, SessionType,
|
||||
SessionType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
@@ -34,15 +37,18 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import \
|
||||
NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import (
|
||||
NodeStateInstructionGreen,
|
||||
)
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.pol.red_agent_pol import (
|
||||
apply_red_agent_iers,
|
||||
apply_red_agent_node_pol,
|
||||
)
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
@@ -59,11 +65,11 @@ class Primaite(Env):
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
"""
|
||||
The Primaite constructor.
|
||||
@@ -237,27 +243,19 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=False,
|
||||
learning_session=True
|
||||
self, transaction_writer=False, learning_session=True
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=True,
|
||||
learning_session=True
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
|
||||
def set_as_eval(self):
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=False,
|
||||
learning_session=False
|
||||
self, transaction_writer=False, learning_session=False
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=True,
|
||||
learning_session=False
|
||||
self, transaction_writer=True, learning_session=False
|
||||
)
|
||||
self.episode_count = 0
|
||||
self.step_count = 0
|
||||
@@ -322,9 +320,7 @@ class Primaite(Env):
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
self.agent_identifier,
|
||||
self.episode_count,
|
||||
self.step_count
|
||||
self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
|
||||
@@ -354,8 +350,9 @@ class Primaite(Env):
|
||||
self.nodes_post_pol = copy.deepcopy(self.nodes)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
# Reference
|
||||
apply_node_pol(self.nodes_reference, self.node_pol,
|
||||
self.step_count) # Node PoL
|
||||
apply_node_pol(
|
||||
self.nodes_reference, self.node_pol, self.step_count
|
||||
) # Node PoL
|
||||
apply_iers(
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
@@ -404,8 +401,10 @@ class Primaite(Env):
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
_LOGGER.info(f"Episode: {self.episode_count}, "
|
||||
f"Average Reward: {self.average_reward}")
|
||||
_LOGGER.info(
|
||||
f"Episode: {self.episode_count}, "
|
||||
f"Average Reward: {self.average_reward}"
|
||||
)
|
||||
# Load the reward into the transaction
|
||||
transaction.reward = reward
|
||||
|
||||
@@ -452,11 +451,11 @@ class Primaite(Env):
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multidiscrete (array) form have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
@@ -525,7 +524,7 @@ class Primaite(Env):
|
||||
# Patch (valid action if it's good or compromised)
|
||||
node.set_service_state(
|
||||
self.services_list[service_index],
|
||||
SoftwareState.PATCHING
|
||||
SoftwareState.PATCHING,
|
||||
)
|
||||
else:
|
||||
# Node is not of Service Type
|
||||
@@ -542,7 +541,10 @@ class Primaite(Env):
|
||||
elif property_action == 2:
|
||||
# Repair
|
||||
# You cannot repair a destroyed file system - it needs restoring
|
||||
if node.file_system_state_actual != FileSystemState.DESTROYED:
|
||||
if (
|
||||
node.file_system_state_actual
|
||||
!= FileSystemState.DESTROYED
|
||||
):
|
||||
node.set_file_system_state(FileSystemState.REPAIRING)
|
||||
elif property_action == 3:
|
||||
# Restore
|
||||
@@ -585,8 +587,9 @@ class Primaite(Env):
|
||||
acl_rule_source = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_source_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(node,
|
||||
ActiveNode):
|
||||
if isinstance(node, ServiceNode) or isinstance(
|
||||
node, ActiveNode
|
||||
):
|
||||
acl_rule_source = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -595,8 +598,9 @@ class Primaite(Env):
|
||||
acl_rule_destination = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_destination_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(node,
|
||||
ActiveNode):
|
||||
if isinstance(node, ServiceNode) or isinstance(
|
||||
node, ActiveNode
|
||||
):
|
||||
acl_rule_destination = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -681,8 +685,9 @@ class Primaite(Env):
|
||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||
"""
|
||||
self.obs_handler = ObservationsHandler.from_config(self,
|
||||
self.obs_config)
|
||||
self.obs_handler = ObservationsHandler.from_config(
|
||||
self, self.obs_config
|
||||
)
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
@@ -790,7 +795,8 @@ class Primaite(Env):
|
||||
service_port = service["port"]
|
||||
service_state = SoftwareState[service["state"]]
|
||||
node.add_service(
|
||||
Service(service_protocol, service_port, service_state))
|
||||
Service(service_protocol, service_port, service_state)
|
||||
)
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
@@ -843,8 +849,9 @@ class Primaite(Env):
|
||||
dest_node_ref: Node = self.nodes_reference[link_destination]
|
||||
|
||||
# Add link to network (reference)
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref,
|
||||
id=link_name)
|
||||
self.network_reference.add_edge(
|
||||
source_node_ref, dest_node_ref, id=link_name
|
||||
)
|
||||
|
||||
# Add link to link dictionary (reference)
|
||||
self.links_reference[link_name] = Link(
|
||||
@@ -1120,7 +1127,8 @@ class Primaite(Env):
|
||||
node_id = item["node_id"]
|
||||
node_class = item["node_class"]
|
||||
node_hardware_state: HardwareState = HardwareState[
|
||||
item["hardware_state"]]
|
||||
item["hardware_state"]
|
||||
]
|
||||
|
||||
node: NodeUnion = self.nodes[node_id]
|
||||
node_ref = self.nodes_reference[node_id]
|
||||
@@ -1186,8 +1194,12 @@ class Primaite(Env):
|
||||
# Use MAX to ensure we get them all
|
||||
for node_action in range(4):
|
||||
for service_state in range(self.num_services):
|
||||
action = [node, node_property, node_action,
|
||||
service_state]
|
||||
action = [
|
||||
node,
|
||||
node_property,
|
||||
node_action,
|
||||
service_state,
|
||||
]
|
||||
# check to see if it's a nothing action (has no effect)
|
||||
if is_valid_node_action(action):
|
||||
actions[action_key] = action
|
||||
|
||||
@@ -46,7 +46,9 @@ def calculate_reward_function(
|
||||
)
|
||||
|
||||
# Software State
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
if isinstance(final_node, ActiveNode) or isinstance(
|
||||
final_node, ServiceNode
|
||||
):
|
||||
reward_value += score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
@@ -81,7 +83,8 @@ def calculate_reward_function(
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
@@ -104,7 +107,9 @@ def calculate_reward_function(
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
@@ -153,7 +158,9 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
@@ -204,7 +211,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
@@ -276,7 +285,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
|
||||
@@ -8,7 +8,9 @@ from primaite.common.protocol import Protocol
|
||||
class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
def __init__(
|
||||
self, _id, _bandwidth, _source_node_name, _dest_node_name, _services
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from primaite.primaite_session import PrimaiteSession
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
"""Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
|
||||
@@ -87,7 +87,9 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
def set_software_state_if_not_compromised(
|
||||
self, software_state: SoftwareState
|
||||
):
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
@@ -98,7 +100,9 @@ class ActiveNode(Node):
|
||||
if self._software_state != SoftwareState.COMPROMISED:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
self.patching_count = (
|
||||
self.config_values.os_patching_duration
|
||||
)
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be changed."
|
||||
@@ -187,7 +191,9 @@ class ActiveNode(Node):
|
||||
def start_file_system_scan(self):
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
self.file_system_scanning_count = (
|
||||
self.config_values.file_system_scanning_limit
|
||||
)
|
||||
|
||||
def update_file_system_state(self):
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
@@ -206,7 +212,10 @@ class ActiveNode(Node):
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
|
||||
# Scanning updates
|
||||
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
|
||||
if (
|
||||
self.file_system_scanning == True
|
||||
and self.file_system_scanning_count < 0
|
||||
):
|
||||
self.file_system_state_observed = self.file_system_state_actual
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
@@ -32,7 +32,9 @@ class NodeStateInstructionGreen(object):
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.service_name = (
|
||||
_service_name # Not used when not a service instruction
|
||||
)
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
|
||||
@@ -42,7 +42,9 @@ class NodeStateInstructionRed(object):
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.service_name = (
|
||||
pol_protocol # Not used when not a service instruction
|
||||
)
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
|
||||
@@ -110,7 +110,9 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
def set_service_state(
|
||||
self, protocol_name: str, software_state: SoftwareState
|
||||
):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from primaite import NOTEBOOKS_DIR, getLogger
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,10 +6,17 @@ from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import (
|
||||
NodeStateInstructionGreen,
|
||||
)
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
@@ -190,7 +197,9 @@ def apply_iers(
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
if (
|
||||
link.get_current_load() + load
|
||||
) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
@@ -204,7 +213,8 @@ def apply_iers(
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
path_node_list[count],
|
||||
path_node_list[count + 1],
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
@@ -216,7 +226,9 @@ def apply_iers(
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
print(
|
||||
"Path not valid - one or more nodes not operational"
|
||||
)
|
||||
pass
|
||||
|
||||
else:
|
||||
@@ -231,7 +243,9 @@ def apply_iers(
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
node_pol: Dict[
|
||||
any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]
|
||||
],
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
@@ -263,16 +277,22 @@ def apply_node_pol(
|
||||
elif node_pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
if isinstance(node, ActiveNode) or isinstance(
|
||||
node, ServiceNode
|
||||
):
|
||||
node.set_software_state_if_not_compromised(state)
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ServiceNode):
|
||||
node.set_service_state_if_not_compromised(service_name, state)
|
||||
node.set_service_state_if_not_compromised(
|
||||
service_name, state
|
||||
)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
if isinstance(node, ActiveNode) or isinstance(
|
||||
node, ServiceNode
|
||||
):
|
||||
node.set_file_system_state_if_not_compromised(state)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
|
||||
@@ -176,7 +176,9 @@ def apply_red_agent_iers(
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
if (
|
||||
link.get_current_load() + load
|
||||
) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
@@ -190,7 +192,8 @@ def apply_red_agent_iers(
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
path_node_list[count],
|
||||
path_node_list[count + 1],
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
@@ -200,16 +203,23 @@ def apply_red_agent_iers(
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
print("Red IER was allowed to run in step " + str(step))
|
||||
print(
|
||||
"Red IER was allowed to run in step "
|
||||
+ str(step)
|
||||
)
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
print(
|
||||
"Path not valid - one or more nodes not operational"
|
||||
)
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Red IER was NOT allowed to run in step " + str(step))
|
||||
print(
|
||||
"Red IER was NOT allowed to run in step " + str(step)
|
||||
)
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
@@ -264,7 +274,9 @@ def apply_red_agent_node_pol(
|
||||
passed_checks = True
|
||||
elif initiator == NodePOLInitiator.IER:
|
||||
# Need to check there is a red IER incoming
|
||||
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
|
||||
passed_checks = is_red_ier_incoming(
|
||||
target_node, iers, pol_type
|
||||
)
|
||||
elif initiator == NodePOLInitiator.SERVICE:
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
@@ -308,7 +320,9 @@ def apply_red_agent_node_pol(
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
print(
|
||||
"Node Red Agent PoL not allowed - did not pass checks"
|
||||
)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
@@ -323,7 +337,10 @@ def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
node_id = node.node_id
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if (
|
||||
ier_value.get_is_running()
|
||||
and ier_value.get_dest_node_id() == node_id
|
||||
):
|
||||
if (
|
||||
node_pol_type == NodePOLType.OPERATING
|
||||
or node_pol_type == NodePOLType.OS
|
||||
|
||||
@@ -1,54 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Optional, Union, Dict
|
||||
from uuid import uuid4
|
||||
from typing import Dict, Final, Optional, Union
|
||||
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \
|
||||
RandomAgent, DummyAgent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, \
|
||||
ActionType, SessionType
|
||||
from primaite.agents.simple import (
|
||||
DoNothingACLAgent,
|
||||
DoNothingNodeAgent,
|
||||
DummyAgent,
|
||||
RandomAgent,
|
||||
)
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
SessionType,
|
||||
)
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = SESSIONS_DIR / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
_LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""
|
||||
The PrimaiteSession class.
|
||||
|
||||
Provides a single learning and evaluation entry point for all training
|
||||
and lay down configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path]
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: The training config path.
|
||||
:param lay_down_config_path: The lay down config path.
|
||||
"""
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
@@ -64,22 +61,35 @@ class PrimaiteSession:
|
||||
)
|
||||
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self):
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}"
|
||||
)
|
||||
if (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.HARDCODED
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.HARDCODED}"
|
||||
)
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
@@ -90,18 +100,23 @@ class PrimaiteSession:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
elif (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.DO_NOTHING
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.DO_NOTHINGD}"
|
||||
)
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
@@ -112,15 +127,26 @@ class PrimaiteSession:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
elif (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.RANDOM
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.RANDOM}"
|
||||
)
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
elif (
|
||||
self._training_config.agent_identifier == AgentIdentifier.DUMMY
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.DUMMY}"
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -128,37 +154,64 @@ class PrimaiteSession:
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}"
|
||||
)
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}"
|
||||
)
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
raise ValueError
|
||||
|
||||
self.session_path: Path = self._agent_session.session_path
|
||||
self.timestamp_str: str = self._agent_session.timestamp_str
|
||||
self.learning_path: Path = self._agent_session.learning_path
|
||||
self.evaluation_path: Path = self._agent_session.evaluation_path
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of time steps per episode.
|
||||
:param episodes: The number of episodes.
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of time steps per episode.
|
||||
:param episodes: The number of episodes.
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
@@ -9,3 +9,14 @@ logging:
|
||||
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
|
||||
# Session
|
||||
session:
|
||||
outputs:
|
||||
plots:
|
||||
size:
|
||||
auto_size: false
|
||||
width: 1500
|
||||
height: 900
|
||||
template: plotly_white
|
||||
range_slider: false
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import NOTEBOOKS_DIR, getLogger
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -24,7 +24,9 @@ def run(overwrite_existing: bool = True):
|
||||
for subdir, dirs, files in os.walk(notebooks_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
|
||||
path_split = os.path.relpath(
|
||||
fp, notebooks_package_data_root
|
||||
).split(os.sep)
|
||||
target_fp = NOTEBOOKS_DIR / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -24,7 +24,9 @@ def run(overwrite_existing=True):
|
||||
for subdir, dirs, files in os.walk(configs_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
|
||||
path_split = os.path.relpath(fp, configs_package_data_root).split(
|
||||
os.sep
|
||||
)
|
||||
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger
|
||||
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,12 +7,7 @@ from typing import List, Tuple
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_identifier,
|
||||
episode_number,
|
||||
step_number
|
||||
):
|
||||
def __init__(self, agent_identifier, episode_number, step_number):
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
@@ -37,6 +32,11 @@ class Transaction(object):
|
||||
"The action space invoked by the agent"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""
|
||||
Converts the Transaction to a csv data row and provides a header.
|
||||
|
||||
:return: A tuple consisting of (header, data).
|
||||
"""
|
||||
if isinstance(self.action_space, int):
|
||||
action_length = self.action_space
|
||||
else:
|
||||
@@ -74,12 +74,14 @@ class Transaction(object):
|
||||
str(self.reward),
|
||||
]
|
||||
row = (
|
||||
row
|
||||
+ _turn_action_space_to_array(self.action_space)
|
||||
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets,
|
||||
obs_features)
|
||||
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets,
|
||||
obs_features)
|
||||
row
|
||||
+ _turn_action_space_to_array(self.action_space)
|
||||
+ _turn_obs_space_to_array(
|
||||
self.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ _turn_obs_space_to_array(
|
||||
self.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
)
|
||||
return header, row
|
||||
|
||||
|
||||
20
src/primaite/utils/session_output_reader.py
Normal file
20
src/primaite/utils/session_output_reader.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
# Using polars as it's faster than Pandas; it will speed things up when
|
||||
# files get big!
|
||||
import polars as pl
|
||||
|
||||
|
||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean
|
||||
reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode csv as a dict.
|
||||
"""
|
||||
d = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])}
|
||||
@@ -1,7 +1,6 @@
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import List, Final, IO, Union, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.transactions.transaction import Transaction
|
||||
@@ -13,15 +12,22 @@ _LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SessionOutputWriter:
|
||||
"""
|
||||
A session output writer class.
|
||||
|
||||
Is used to write session outputs to csv file.
|
||||
"""
|
||||
|
||||
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
"Episode", "Average Reward"
|
||||
"Episode",
|
||||
"Average Reward",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True
|
||||
self,
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True,
|
||||
):
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
@@ -52,14 +58,21 @@ class SessionOutputWriter:
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Close the csv file."""
|
||||
if self._csv_file:
|
||||
self._csv_file.close()
|
||||
_LOGGER.info(f"Finished writing file: {self._csv_file_path}")
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(
|
||||
self,
|
||||
data: Union[Tuple, Transaction]
|
||||
):
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
:param data: The row of data to write. Can be a Tuple or an instance
|
||||
of Transaction.
|
||||
"""
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
else:
|
||||
@@ -69,5 +82,4 @@ class SessionOutputWriter:
|
||||
self._init_csv_writer()
|
||||
self._csv_writer.writerow(header)
|
||||
self._first_write = False
|
||||
|
||||
self._csv_writer.writerow(data)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray[RLlib])
|
||||
# "NONE" (Custom Agent)
|
||||
agent_framework: RLLIB
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which Red Agent algo/class will be used:
|
||||
# "PPO" (Proximal Policy Optimization)
|
||||
@@ -27,7 +27,7 @@ num_steps: 256
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 10
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: A2C
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: STABLE_BASELINES3_A2C
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -28,7 +39,7 @@ observation_space:
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: RANDOM
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: NONE
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -24,7 +35,7 @@ observation_space:
|
||||
time_delay: 1
|
||||
# Filename of the scenario / laydown
|
||||
|
||||
session_type: TRAINING
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: RANDOM
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: NONE
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -25,7 +36,7 @@ observation_space:
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: RANDOM
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: NONE
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -18,7 +29,7 @@ num_steps: 5
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: DUMMY
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
agent_identifier: GENERIC
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -18,7 +30,7 @@ num_steps: 15
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: EVAL
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: RANDOM
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: GENERIC
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -18,7 +29,7 @@ num_steps: 15
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: EVAL
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: RANDOM
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: GENERIC
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
@@ -18,7 +29,7 @@ num_steps: 5
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
# Type of session to be run (e.g. TRAIN or EVAL)
|
||||
session_type: TRAINING
|
||||
session_type: EVAL
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
|
||||
@@ -1,37 +1,151 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import datetime
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Dict, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
Uses context manager for deletion of files upon exit.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
|
||||
def learn_av_reward_per_episode(self) -> Dict[int, float]:
|
||||
"""Get the learn av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
|
||||
"""Get the eval av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.evaluation_path / csv_file)
|
||||
|
||||
@property
|
||||
def env(self) -> Primaite:
|
||||
"""Direct access to the env for ease of testing."""
|
||||
return self._agent_session._env # noqa
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
del self._agent_session._env.episode_av_reward_writer
|
||||
del self._agent_session._env.transaction_writer
|
||||
shutil.rmtree(self.session_path)
|
||||
shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_primaite_session(request):
|
||||
"""
|
||||
Provides a temporary PrimaiteSession instance.
|
||||
|
||||
It's temporary as it uses a temporary directory as the session path.
|
||||
|
||||
To use this fixture you need to:
|
||||
|
||||
- parametrize your test function with:
|
||||
|
||||
- "temp_primaite_session"
|
||||
- [[path to training config, path to lay down config]]
|
||||
- Include the temp_primaite_session fixture as a param in your test
|
||||
function.
|
||||
- use the temp_primaite_session as a context manager assigning it the
|
||||
name 'session'.
|
||||
|
||||
.. code:: python
|
||||
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[main_training_config_path(), dos_very_basic_config_path()]
|
||||
],
|
||||
indirect=True
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
with temp_primaite_session as session:
|
||||
# Learning outputs are saved in session.learning_path
|
||||
session.learn()
|
||||
|
||||
# Evaluation outputs are saved in session.evaluation_path
|
||||
session.evaluate()
|
||||
|
||||
# To ensure that all files are written, you must call .close()
|
||||
session.close()
|
||||
|
||||
# If you need to inspect any session outputs, it must be done
|
||||
# inside the context manager
|
||||
|
||||
# Now that we've exited the context manager, the
|
||||
# session.session_path directory and its contents are deleted
|
||||
"""
|
||||
training_config_path = request.param[0]
|
||||
lay_down_config_path = request.param[1]
|
||||
with patch(
|
||||
"primaite.agents.agent.get_session_path", get_temp_session_path
|
||||
) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_session_path() -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
session_timestamp = datetime.now()
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path = (
|
||||
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
)
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
def _get_primaite_env_from_config(
|
||||
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
temp_session_path,
|
||||
):
|
||||
"""Takes a config path and returns the created instance of Primaite."""
|
||||
session_timestamp: datetime = datetime.now()
|
||||
session_path = _get_temp_session_path(session_timestamp)
|
||||
session_path = temp_session_path(session_timestamp)
|
||||
|
||||
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
env = Primaite(
|
||||
@@ -45,7 +159,7 @@ def _get_primaite_env_from_config(
|
||||
|
||||
# TODO: This needs to be refactored to happen outside. Should be part of
|
||||
# a main Session class.
|
||||
if env.training_config.agent_identifier == "GENERIC":
|
||||
if env.training_config.agent_identifier is AgentIdentifier.RANDOM:
|
||||
run_generic(env, config_values)
|
||||
|
||||
return env
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.main import run
|
||||
|
||||
|
||||
def test_primaite_main_e2e():
|
||||
"""Tests the primaite.main.run function end-to-end."""
|
||||
run(main_training_config_path(), data_manipulation_config_path())
|
||||
24
tests/mock_and_patch/get_session_path_mock.py
Normal file
24
tests/mock_and_patch/get_session_path_mock.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = (
|
||||
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
)
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
_LOGGER.debug(f"Created temp session directory: {session_path}")
|
||||
return session_path
|
||||
@@ -60,7 +60,9 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state):
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
|
||||
active_node.set_software_state_if_not_compromised(
|
||||
SoftwareState.OVERWHELMED
|
||||
)
|
||||
|
||||
assert active_node.software_state == expected_state
|
||||
|
||||
@@ -98,7 +100,9 @@ def test_file_system_change(operating_state, expected_state):
|
||||
(HardwareState.ON, FileSystemState.CORRUPT),
|
||||
],
|
||||
)
|
||||
def test_file_system_change_if_not_compromised(operating_state, expected_state):
|
||||
def test_file_system_change_if_not_compromised(
|
||||
operating_state, expected_state
|
||||
):
|
||||
"""
|
||||
Test that a node cannot change its file system state.
|
||||
|
||||
@@ -116,6 +120,8 @@ def test_file_system_change_if_not_compromised(operating_state, expected_state):
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
|
||||
active_node.set_file_system_state_if_not_compromised(
|
||||
FileSystemState.CORRUPT
|
||||
)
|
||||
|
||||
assert active_node.file_system_state_actual == expected_state
|
||||
|
||||
@@ -7,79 +7,78 @@ from primaite.environment.observations import (
|
||||
NodeStatuses,
|
||||
ObservationsHandler,
|
||||
)
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_primaite_env_from_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env(request):
|
||||
"""Build Primaite environment for integration tests of observation space."""
|
||||
marker = request.node.get_closest_marker("env_config_paths")
|
||||
training_config_path = marker.args[0]["training_config_path"]
|
||||
lay_down_config_path = marker.args[0]["lay_down_config_path"]
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path,
|
||||
)
|
||||
yield env
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_default_obs_space(env: Primaite):
|
||||
def test_default_obs_space(temp_primaite_session):
|
||||
"""Create environment with no obs space defined in config and check that the default obs space was created."""
|
||||
env.update_environent_obs()
|
||||
with temp_primaite_session as session:
|
||||
session.env.update_environent_obs()
|
||||
|
||||
components = env.obs_handler.registered_obs_components
|
||||
components = session.env.obs_handler.registered_obs_components
|
||||
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], NodeLinkTable)
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], NodeLinkTable)
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_registering_components(env: Primaite):
|
||||
def test_registering_components(temp_primaite_session):
|
||||
"""Test regitering and deregistering a component."""
|
||||
handler = ObservationsHandler()
|
||||
component = NodeStatuses(env)
|
||||
handler.register(component)
|
||||
assert component in handler.registered_obs_components
|
||||
handler.deregister(component)
|
||||
assert component not in handler.registered_obs_components
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
handler = ObservationsHandler()
|
||||
component = NodeStatuses(env)
|
||||
handler.register(component)
|
||||
assert component in handler.registered_obs_components
|
||||
handler.deregister(component)
|
||||
assert component not in handler.registered_obs_components
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_LINK_TABLE.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestNodeLinkTable:
|
||||
"""Test the NodeLinkTable observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, env: Primaite):
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with box observation space."""
|
||||
env.update_environent_obs()
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have three nodes and two links, with two services
|
||||
# therefore the box observation space will have:
|
||||
# * 5 rows (3 nodes + 2 links)
|
||||
# * 6 columns (four fixed and two for the services)
|
||||
assert env.env_obs.shape == (5, 6)
|
||||
# we have three nodes and two links, with two services
|
||||
# therefore the box observation space will have:
|
||||
# * 5 rows (3 nodes + 2 links)
|
||||
# * 6 columns (four fixed and two for the services)
|
||||
assert env.env_obs.shape == (5, 6)
|
||||
|
||||
def test_value(self, env: Primaite):
|
||||
def test_value(self, temp_primaite_session):
|
||||
"""Test that the observation is generated correctly.
|
||||
|
||||
The laydown has:
|
||||
@@ -125,36 +124,45 @@ class TestNodeLinkTable:
|
||||
* 999 (999 traffic service1)
|
||||
* 0 (no traffic for service2)
|
||||
"""
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(
|
||||
0
|
||||
) # apply the 'do nothing' action
|
||||
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
[
|
||||
[1, 1, 3, 1, 1, 1],
|
||||
[2, 1, 1, 1, 1, 4],
|
||||
[3, 1, 1, 1, 0, 0],
|
||||
[4, 0, 0, 0, 999, 0],
|
||||
[5, 0, 0, 0, 999, 0],
|
||||
],
|
||||
)
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
[
|
||||
[1, 1, 3, 1, 1, 1],
|
||||
[2, 1, 1, 1, 1, 4],
|
||||
[3, 1, 1, 1, 0, 0],
|
||||
[4, 0, 0, 0, 999, 0],
|
||||
[5, 0, 0, 0, 999, 0],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestNodeStatuses:
|
||||
"""Test the NodeStatuses observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, env: Primaite):
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with NodeStatuses as the only component."""
|
||||
assert env.env_obs.shape == (15,)
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
assert env.env_obs.shape == (15,)
|
||||
|
||||
def test_values(self, env: Primaite):
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that the hardware and software states are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
@@ -181,28 +189,38 @@ class TestNodeStatuses:
|
||||
* service 1 = n/a (0)
|
||||
* service 2 = n/a (0)
|
||||
"""
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
assert np.array_equal(
|
||||
obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestLinkTrafficLevels:
|
||||
"""Test the LinkTrafficLevels observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, env: Primaite):
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with MultiDiscrete observation space."""
|
||||
env.update_environent_obs()
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have two links and two services, so the shape should be 2 * 2
|
||||
assert env.env_obs.shape == (2 * 2,)
|
||||
# we have two links and two services, so the shape should be 2 * 2
|
||||
assert env.env_obs.shape == (2 * 2,)
|
||||
|
||||
def test_values(self, env: Primaite):
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
@@ -212,12 +230,14 @@ class TestLinkTrafficLevels:
|
||||
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
|
||||
* link bandwidth of 1000, therefore the utilisation is 99.9%
|
||||
"""
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
# the observation space has combine_service_traffic set to False, so the space has this format:
|
||||
# [link1_service1, link1_service2, link2_service1, link2_service2]
|
||||
# we send 999 bits of data via link1 and link2 on service 1.
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
# the observation space has combine_service_traffic set to False, so the space has this format:
|
||||
# [link1_service1, link1_service2, link2_service1, link2_service2]
|
||||
# we send 999 bits of data via link1 and link2 on service 1.
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
61
tests/test_primaite_session.py
Normal file
61
tests/test_primaite_session.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[main_training_config_path(), dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
"""Tests the PrimaiteSession class and its outputs."""
|
||||
with temp_primaite_session as session:
|
||||
session_path = session.session_path
|
||||
assert session_path.exists()
|
||||
session.learn()
|
||||
# Learning outputs are saved in session.learning_path
|
||||
session.evaluate()
|
||||
# Evaluation outputs are saved in session.evaluation_path
|
||||
|
||||
# If you need to inspect any session outputs, it must be done inside
|
||||
# the context manager
|
||||
|
||||
# Check that the metadata json file exists
|
||||
assert (session_path / "session_metadata.json").exists()
|
||||
|
||||
# Check that the network png file exists
|
||||
assert (session_path / f"network_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.evaluation_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
for file in file_names:
|
||||
path = os.path.join(dir_path, file)
|
||||
file_str = path.split(str(session_path))[-1]
|
||||
_LOGGER.debug(f" {file_str}")
|
||||
|
||||
# Now that we've exited the context manager, the session.session_path
|
||||
# directory and its contents are deleted
|
||||
assert not session_path.exists()
|
||||
@@ -18,7 +18,9 @@ from primaite.nodes.service_node import ServiceNode
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
def test_node_resets_correctly(
|
||||
starting_operating_state, expected_operating_state
|
||||
):
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id="0",
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_primaite_env_from_config
|
||||
|
||||
|
||||
def test_rewards_are_being_penalised_at_each_step_function():
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_rewards_are_being_penalised_at_each_step_function(
|
||||
temp_primaite_session,
|
||||
):
|
||||
"""
|
||||
Test that hardware state is penalised at each step.
|
||||
|
||||
When the initial state is OFF compared to reference state which is ON.
|
||||
"""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "one_node_states_on_off_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "one_node_states_on_off_lay_down_config.yaml",
|
||||
)
|
||||
|
||||
"""
|
||||
On different steps (of the 13 in total) these are the following rewards for config_6 which are activated:
|
||||
On different steps (of the 13 in total) these are the following rewards
|
||||
for config_6 which are activated:
|
||||
File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3)
|
||||
Hardware State: onShouldBeOff = -2 (between Steps 4 & 6)
|
||||
Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9)
|
||||
Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12)
|
||||
Software State (Software State): goodShouldBeCompromised = 5 (between
|
||||
Steps 10 & 12)
|
||||
|
||||
Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
|
||||
Step Count: 13
|
||||
@@ -28,5 +35,8 @@ def test_rewards_are_being_penalised_at_each_step_function():
|
||||
For the 4 steps where this occurs the average reward is:
|
||||
Average Reward: 2 (26 / 13)
|
||||
"""
|
||||
print("average reward", env.average_reward)
|
||||
assert env.average_reward == -8.0
|
||||
with temp_primaite_session as session:
|
||||
session.evaluate()
|
||||
session.close()
|
||||
ev_rewards = session.eval_av_reward_per_episode_csv()
|
||||
assert ev_rewards[1] == -8.0
|
||||
|
||||
@@ -45,7 +45,9 @@ def test_service_state_change(operating_state, expected_state):
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_service_state_change_if_not_comprised(operating_state, expected_state):
|
||||
def test_service_state_change_if_not_comprised(
|
||||
operating_state, expected_state
|
||||
):
|
||||
"""
|
||||
Test that a node cannot change the state of a running service.
|
||||
|
||||
@@ -65,6 +67,8 @@ def test_service_state_change_if_not_comprised(operating_state, expected_state):
|
||||
service = Service("TCP", 80, SoftwareState.GOOD)
|
||||
service_node.add_service(service)
|
||||
|
||||
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
|
||||
service_node.set_service_state_if_not_compromised(
|
||||
"TCP", SoftwareState.OVERWHELMED
|
||||
)
|
||||
|
||||
assert service_node.get_service_state("TCP") == expected_state
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.common.enums import HardwareState
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_primaite_env_from_config
|
||||
|
||||
|
||||
def run_generic_set_actions(env: Primaite):
|
||||
@@ -44,59 +45,72 @@ def run_generic_set_actions(env: Primaite):
|
||||
# env.close()
|
||||
|
||||
|
||||
def test_single_action_space_is_valid():
|
||||
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"""Test single action space is valid."""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
|
||||
run_generic_set_actions(env)
|
||||
|
||||
# Retrieve the action space dictionary values from environment
|
||||
env_action_space_dict = env.action_dict.values()
|
||||
# Flags to check the conditions of the action space
|
||||
contains_acl_actions = False
|
||||
contains_node_actions = False
|
||||
both_action_spaces = False
|
||||
# Loop through each element of the list (which is every value from the dictionary)
|
||||
for dict_item in env_action_space_dict:
|
||||
# Node action detected
|
||||
if len(dict_item) == 4:
|
||||
contains_node_actions = True
|
||||
# ACL action detected
|
||||
elif len(dict_item) == 6:
|
||||
contains_acl_actions = True
|
||||
# If both are there then the ANY action type is working
|
||||
if contains_node_actions and contains_acl_actions:
|
||||
both_action_spaces = True
|
||||
# Check condition should be True
|
||||
assert both_action_spaces
|
||||
run_generic_set_actions(env)
|
||||
# Retrieve the action space dictionary values from environment
|
||||
env_action_space_dict = env.action_dict.values()
|
||||
# Flags to check the conditions of the action space
|
||||
contains_acl_actions = False
|
||||
contains_node_actions = False
|
||||
both_action_spaces = False
|
||||
# Loop through each element of the list (which is every value from the dictionary)
|
||||
for dict_item in env_action_space_dict:
|
||||
# Node action detected
|
||||
if len(dict_item) == 4:
|
||||
contains_node_actions = True
|
||||
# ACL action detected
|
||||
elif len(dict_item) == 6:
|
||||
contains_acl_actions = True
|
||||
# If both are there then the ANY action type is working
|
||||
if contains_node_actions and contains_acl_actions:
|
||||
both_action_spaces = True
|
||||
# Check condition should be True
|
||||
assert both_action_spaces
|
||||
|
||||
|
||||
def test_agent_is_executing_actions_from_both_spaces():
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
run_generic_set_actions(env)
|
||||
# Retrieve hardware state of computer_1 node in laydown config
|
||||
# Agent turned this off in Step 5
|
||||
computer_node_hardware_state = env.nodes["1"].hardware_state
|
||||
# Retrieve the Access Control List object stored by the environment at the end of the episode
|
||||
access_control_list = env.acl
|
||||
# Use the Access Control List object's acl attribute to get the dictionary
|
||||
# Use dictionary.values() to get total list of all items in the dictionary
|
||||
acl_rules_list = access_control_list.acl.values()
|
||||
# Length of this list tells you how many items are in the dictionary
|
||||
# This number is the frequency of Access Control Rules in the environment
|
||||
# In the scenario, we specified that the agent should create only 1 acl rule
|
||||
num_of_rules = len(acl_rules_list)
|
||||
# Therefore these statements below MUST be true
|
||||
assert computer_node_hardware_state == HardwareState.OFF
|
||||
assert num_of_rules == 1
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
run_generic_set_actions(env)
|
||||
# Retrieve hardware state of computer_1 node in laydown config
|
||||
# Agent turned this off in Step 5
|
||||
computer_node_hardware_state = env.nodes["1"].hardware_state
|
||||
# Retrieve the Access Control List object stored by the environment at the end of the episode
|
||||
access_control_list = env.acl
|
||||
# Use the Access Control List object's acl attribute to get the dictionary
|
||||
# Use dictionary.values() to get total list of all items in the dictionary
|
||||
acl_rules_list = access_control_list.acl.values()
|
||||
# Length of this list tells you how many items are in the dictionary
|
||||
# This number is the frequency of Access Control Rules in the environment
|
||||
# In the scenario, we specified that the agent should create only 1 acl rule
|
||||
num_of_rules = len(acl_rules_list)
|
||||
# Therefore these statements below MUST be true
|
||||
assert computer_node_hardware_state == HardwareState.OFF
|
||||
assert num_of_rules == 1
|
||||
|
||||
@@ -16,7 +16,9 @@ def test_legacy_lay_down_config_yaml_conversion():
|
||||
with open(new_path, "r") as file:
|
||||
new_dict = yaml.safe_load(file)
|
||||
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(
|
||||
legacy_dict
|
||||
)
|
||||
|
||||
for key, value in new_dict.items():
|
||||
assert converted_dict[key] == value
|
||||
|
||||
Reference in New Issue
Block a user