Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules
This commit is contained in:
4
docs/_templates/custom-class-template.rst
vendored
4
docs/_templates/custom-class-template.rst
vendored
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
..
|
||||
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
|
||||
..
|
||||
|
||||
4
docs/_templates/custom-module-template.rst
vendored
4
docs/_templates/custom-module-template.rst
vendored
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
..
|
||||
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
|
||||
..
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
..
|
||||
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
|
||||
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
Welcome to PrimAITE's documentation
|
||||
====================================
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
.. _about:
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
.. _about:
|
||||
|
||||
About PrimAITE
|
||||
==============
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
.. _config:
|
||||
|
||||
The Config Files Explained
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
Custom Agents
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
Custom Agents
|
||||
=============
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
.. role:: raw-html(raw)
|
||||
:format: html
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
.. _getting-started:
|
||||
|
||||
Getting Started
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
Glossary
|
||||
=============
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
v1.2 to v2.0 Migration guide
|
||||
============================
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
.. only:: comment
|
||||
|
||||
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
.. _run a primaite session:
|
||||
|
||||
Run a PrimAITE Session
|
||||
@@ -44,7 +48,8 @@ For example, when running a session at 17:30:00 on 31st January 2023, the sessio
|
||||
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
||||
|
||||
Loading a session
|
||||
-------
|
||||
-----------------
|
||||
|
||||
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
|
||||
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from setuptools import setup
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import logging
|
||||
import logging.config
|
||||
import sys
|
||||
@@ -6,7 +6,7 @@ from bisect import bisect
|
||||
from logging import Formatter, Logger, LogRecord, StreamHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final
|
||||
from typing import Any, Dict, Final
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
@@ -16,7 +16,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
"""An instance of `PlatformDirs` set with appname='primaite'."""
|
||||
|
||||
|
||||
def _get_primaite_config():
|
||||
def _get_primaite_config() -> Dict:
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
|
||||
def __init__(self, formats: Dict[int, str], **kwargs):
|
||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Access Control List. Models firewall functionality."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
import logging
|
||||
from typing import Final, List, Union
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
@@ -12,7 +12,7 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self, implicit_permission, max_acl_rules):
|
||||
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
@@ -30,7 +30,7 @@ class AccessControlList:
|
||||
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
@property
|
||||
def acl(self):
|
||||
def acl(self) -> List[Union[ACLRule, None]]:
|
||||
"""Public access method for private _acl."""
|
||||
return self._acl + [self.acl_implicit_rule]
|
||||
|
||||
@@ -84,7 +84,9 @@ class AccessControlList:
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port, _position):
|
||||
def add_rule(
|
||||
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: int
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
@@ -141,12 +143,12 @@ class AccessControlList:
|
||||
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
|
||||
self._acl[index] = None
|
||||
|
||||
def remove_all_rules(self):
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
for i in range(len(self._acl)):
|
||||
self._acl[i] = None
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
@@ -164,7 +166,9 @@ class AccessControlList:
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
def get_relevant_rules(
|
||||
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all ACL rules that relate to the given arguments.
|
||||
|
||||
:param _source_ip_address: the source IP address to check
|
||||
@@ -172,7 +176,7 @@ class AccessControlList:
|
||||
:param _protocol: the protocol to check
|
||||
:param _port: the port to check
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[str, ACLRule]
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules = {}
|
||||
for rule in self.acl:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
@@ -6,7 +6,9 @@ from primaite.common.enums import RulePermissionType
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission: RulePermissionType, _source_ip, _dest_ip, _protocol, _port):
|
||||
def __init__(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
@@ -16,13 +18,13 @@ class ACLRule:
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission = _permission
|
||||
self.source_ip = _source_ip
|
||||
self.dest_ip = _dest_ip
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.permission: RulePermissionType = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
Override the hash function.
|
||||
|
||||
@@ -39,7 +41,7 @@ class ACLRule:
|
||||
)
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
def get_permission(self) -> str:
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
|
||||
@@ -48,7 +50,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
def get_source_ip(self) -> str:
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
|
||||
@@ -57,7 +59,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
def get_dest_ip(self) -> str:
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
@@ -66,7 +68,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
|
||||
@@ -75,7 +77,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets the port attribute.
|
||||
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Common interface between RL agents from different libraries and PrimAITE."""
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
@@ -15,7 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
@@ -50,7 +52,7 @@ class AgentSessionABC(ABC):
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
@@ -130,11 +132,11 @@ class AgentSessionABC(ABC):
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
"""The Agent Session UUID."""
|
||||
return self._uuid
|
||||
|
||||
def _write_session_metadata_file(self):
|
||||
def _write_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Write the ``session_metadata.json`` file.
|
||||
|
||||
@@ -170,7 +172,7 @@ class AgentSessionABC(ABC):
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished writing session metadata file")
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
@@ -199,7 +201,7 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
@@ -209,14 +211,14 @@ class AgentSessionABC(ABC):
|
||||
self._can_evaluate = False
|
||||
|
||||
@abstractmethod
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -233,8 +235,8 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -247,10 +249,10 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def load(self, path: Union[str, Path]):
|
||||
def load(self, path: Union[str, Path]) -> None:
|
||||
"""Load an agent from file."""
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
@@ -274,21 +276,21 @@ class AgentSessionABC(ABC):
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._env.episode_av_reward_writer.close() # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True):
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -23,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
@@ -36,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -47,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -65,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -102,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path=None):
|
||||
def load(cls, path: Union[str, Path] = None) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -32,7 +33,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[Any, Any]:
|
||||
) -> Dict[str, IER]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
@@ -60,7 +61,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
|
||||
def get_matching_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
@@ -83,7 +86,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
@@ -111,7 +114,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_allow_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
@@ -141,7 +144,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Filter ACL rules to only those which are relevant to the specified nodes.
|
||||
|
||||
:param source_node_id: Source node
|
||||
@@ -173,6 +176,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
# TODO: This should throw an error because protocol is a string
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
@@ -186,7 +190,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List ALLOW rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
@@ -233,7 +237,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List DENY rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
@@ -101,6 +102,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
|
||||
action = transform_action_node_enum(action, action_dict)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -18,10 +20,11 @@ from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def _env_creator(env_config):
|
||||
# TODO: verify type of env_config
|
||||
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
@@ -30,11 +33,12 @@ def _env_creator(env_config):
|
||||
)
|
||||
|
||||
|
||||
def _custom_log_creator(session_path: Path):
|
||||
# TODO: verify type hint return type
|
||||
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def logger_creator(config):
|
||||
def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
return logger_creator
|
||||
@@ -48,7 +52,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
@@ -73,6 +77,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
@@ -94,7 +99,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
@@ -122,7 +127,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._agent_config = self._agent_config_class()
|
||||
@@ -148,7 +153,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
save_checkpoint = False
|
||||
@@ -159,8 +164,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -180,8 +185,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: None,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -189,7 +194,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -197,7 +202,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, overwrite_existing: bool = True):
|
||||
def save(self, overwrite_existing: bool = True) -> None:
|
||||
"""Save the agent."""
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
@@ -217,6 +222,6 @@ class RLlibAgent(AgentSessionABC):
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -13,7 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
@@ -24,7 +26,7 @@ class SB3Agent(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
@@ -42,6 +44,7 @@ class SB3Agent(AgentSessionABC):
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_class: Union[PPO, A2C]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
@@ -65,7 +68,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
"""Set up the SB3 Agent."""
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
@@ -112,7 +115,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
super()._setup()
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
save_checkpoint = False
|
||||
@@ -123,13 +126,13 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -152,8 +155,8 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -182,10 +185,10 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
@@ -9,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC):
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return self._env.action_space.sample()
|
||||
|
||||
|
||||
@@ -20,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC):
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
@@ -31,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
@@ -46,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = [1, "NONE", "ON", 0]
|
||||
nothing_action = transform_action_node_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -34,11 +35,11 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
||||
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
|
||||
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import logging
|
||||
import os
|
||||
@@ -19,7 +19,7 @@ app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def build_dirs():
|
||||
def build_dirs() -> None:
|
||||
"""Build the PrimAITE app directories."""
|
||||
from primaite.setup import setup_app_dirs
|
||||
|
||||
@@ -27,7 +27,7 @@ def build_dirs():
|
||||
|
||||
|
||||
@app.command()
|
||||
def reset_notebooks(overwrite: bool = True):
|
||||
def reset_notebooks(overwrite: bool = True) -> None:
|
||||
"""
|
||||
Force a reset of the demo notebooks in the users notebooks directory.
|
||||
|
||||
@@ -39,7 +39,7 @@ def reset_notebooks(overwrite: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
|
||||
@@ -60,7 +60,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
|
||||
|
||||
|
||||
@app.command()
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the PrimAITE Log Level.
|
||||
|
||||
@@ -88,7 +88,7 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks():
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
@@ -96,7 +96,7 @@ def notebooks():
|
||||
|
||||
|
||||
@app.command()
|
||||
def version():
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
import primaite
|
||||
|
||||
@@ -104,7 +104,7 @@ def version():
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up():
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
@@ -112,7 +112,7 @@ def clean_up():
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True):
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None):
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
@@ -185,7 +185,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Objects which are shared between many PrimAITE modules."""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Type, Union
|
||||
from typing import Union
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
@@ -148,6 +148,7 @@ class ActionType(Enum):
|
||||
ANY = 2
|
||||
|
||||
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name):
|
||||
def __init__(self, _name: str) -> None:
|
||||
"""
|
||||
Initialise a protocol.
|
||||
|
||||
:param _name: The name of the protocol
|
||||
:type _name: str
|
||||
"""
|
||||
self.name = _name
|
||||
self.load = 0 # bps
|
||||
self.name: str = _name
|
||||
self.load: int = 0 # bps
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Gets the protocol name.
|
||||
|
||||
@@ -24,7 +24,7 @@ class Protocol(object):
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets the protocol load.
|
||||
|
||||
@@ -33,7 +33,7 @@ class Protocol(object):
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load):
|
||||
def add_load(self, _load: int) -> None:
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
|
||||
@@ -42,6 +42,6 @@ class Protocol(object):
|
||||
"""
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self):
|
||||
def clear_load(self) -> None:
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SoftwareState
|
||||
@@ -7,7 +7,7 @@ from primaite.common.enums import SoftwareState
|
||||
class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState):
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Initialise a service.
|
||||
|
||||
@@ -15,12 +15,12 @@ class Service(object):
|
||||
:param port: The service port.
|
||||
:param software_state: The service SoftwareState.
|
||||
"""
|
||||
self.name = name
|
||||
self.port = port
|
||||
self.software_state = software_state
|
||||
self.patching_count = 0
|
||||
self.name: str = name
|
||||
self.port: str = port
|
||||
self.software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
|
||||
def reduce_patching_count(self):
|
||||
def reduce_patching_count(self) -> None:
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Configuration parameters for running experiments."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union
|
||||
|
||||
@@ -6,7 +7,7 @@ import yaml
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
@@ -19,7 +20,7 @@ from primaite.common.enums import (
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
@@ -86,7 +87,7 @@ class TrainingConfig:
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: str = False
|
||||
load_agent: bool = False
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
@@ -198,7 +199,7 @@ class TrainingConfig:
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
@@ -216,12 +217,14 @@ class TrainingConfig:
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
|
||||
@@ -341,7 +344,7 @@ def convert_legacy_training_config_dict(
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Utility to generate plots of sessions metrics after PrimAITE."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -17,14 +19,15 @@ from primaite.nodes.service_node import ServiceNode
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_LOGGER: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise observation component.
|
||||
|
||||
@@ -39,7 +42,7 @@ class AbstractObservationComponent(ABC):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@@ -74,7 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
_MAX_VAL: int = 1_000_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeLinkTable observation space component.
|
||||
|
||||
@@ -101,7 +104,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -148,7 +151,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
nodes = self.env.nodes.values()
|
||||
links = self.env.links.values()
|
||||
@@ -211,7 +214,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeStatuses observation component.
|
||||
|
||||
@@ -237,7 +240,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -268,7 +271,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
services = self.env.services_list
|
||||
|
||||
@@ -318,7 +321,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a LinkTrafficLevels observation component.
|
||||
|
||||
@@ -360,7 +363,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -386,7 +389,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for _, link in self.env.links.items():
|
||||
@@ -470,7 +473,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
@@ -550,7 +553,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for acl_rule in self.env.acl.acl:
|
||||
@@ -593,7 +596,7 @@ class ObservationsHandler:
|
||||
"ACCESS_CONTROL_LIST": AccessControlList,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the observation handler."""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
@@ -606,7 +609,7 @@ class ObservationsHandler:
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
def update_obs(self):
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
@@ -619,7 +622,7 @@ class ObservationsHandler:
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
def register(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Add a component for this handler to track.
|
||||
|
||||
@@ -629,7 +632,7 @@ class ObservationsHandler:
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent):
|
||||
def deregister(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Remove a component from this handler.
|
||||
|
||||
@@ -640,7 +643,7 @@ class ObservationsHandler:
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self):
|
||||
def update_space(self) -> None:
|
||||
"""Rebuild the handler's composite observation space from its components."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
@@ -657,7 +660,7 @@ class ObservationsHandler:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
|
||||
def space(self):
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_space
|
||||
@@ -665,7 +668,7 @@ class ObservationsHandler:
|
||||
return self._space
|
||||
|
||||
@property
|
||||
def current_observation(self):
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_observation
|
||||
@@ -673,7 +676,7 @@ class ObservationsHandler:
|
||||
return self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
|
||||
"""
|
||||
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
@@ -716,7 +719,7 @@ class ObservationsHandler:
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
|
||||
def describe_structure(self) -> List[str]:
|
||||
"""
|
||||
Create a list of names for the features of the obs space.
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
||||
import copy
|
||||
import logging
|
||||
import uuid as uuid
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Dict, Final, Tuple, Union
|
||||
from typing import Any, Dict, Final, List, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -20,6 +21,7 @@ from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
@@ -48,7 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
@@ -66,7 +68,7 @@ class Primaite(Env):
|
||||
lay_down_config_path: Union[str, Path],
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The Primaite constructor.
|
||||
|
||||
@@ -77,13 +79,14 @@ class Primaite(Env):
|
||||
"""
|
||||
self.session_path: Final[Path] = session_path
|
||||
self.timestamp_str: Final[str] = timestamp_str
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
self._training_config_path: Union[str, Path] = training_config_path
|
||||
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
|
||||
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
self.episode_steps: int
|
||||
if self.training_config.session_type == SessionType.TRAIN:
|
||||
self.episode_steps = self.training_config.num_train_steps
|
||||
elif self.training_config.session_type == SessionType.EVAL:
|
||||
@@ -94,7 +97,7 @@ class Primaite(Env):
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.training_config.agent_identifier
|
||||
self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier
|
||||
|
||||
# Create a dictionary to hold all the nodes
|
||||
self.nodes: Dict[str, NodeUnion] = {}
|
||||
@@ -113,42 +116,42 @@ class Primaite(Env):
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
self.node_pol: Dict[str, NodeStateInstructionGreen] = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
|
||||
self.red_iers = {}
|
||||
self.red_iers: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent node PoLs (this will come from an external source)
|
||||
self.red_node_pol = {}
|
||||
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
|
||||
|
||||
# Create the Access Control List
|
||||
self.acl = AccessControlList(
|
||||
self.acl: AccessControlList = AccessControlList(
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
# Sets limit for number of ACL rules in environment
|
||||
self.max_number_acl_rules = self.training_config.max_number_acl_rules
|
||||
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
self.services_list: List[str] = []
|
||||
|
||||
# Create a list of ports
|
||||
self.ports_list = []
|
||||
self.ports_list: List[str] = []
|
||||
|
||||
# Create graph (network)
|
||||
self.network = nx.MultiGraph()
|
||||
self.network: nx.Graph = nx.MultiGraph()
|
||||
|
||||
# Create a graph (network) reference
|
||||
self.network_reference = nx.MultiGraph()
|
||||
self.network_reference: nx.Graph = nx.MultiGraph()
|
||||
|
||||
# Create step count
|
||||
self.step_count = 0
|
||||
self.step_count: int = 0
|
||||
|
||||
self.total_step_count: int = 0
|
||||
"""The total number of time steps completed."""
|
||||
|
||||
# Create step info dictionary
|
||||
self.step_info = {}
|
||||
self.step_info: Dict[Any] = {}
|
||||
|
||||
# Total reward
|
||||
self.total_reward: float = 0
|
||||
@@ -157,22 +160,23 @@ class Primaite(Env):
|
||||
self.average_reward: float = 0
|
||||
|
||||
# Episode count
|
||||
self.episode_count = 0
|
||||
self.episode_count: int = 0
|
||||
|
||||
# Number of nodes - gets a value by examining the nodes dictionary after it's been populated
|
||||
self.num_nodes = 0
|
||||
self.num_nodes: int = 0
|
||||
|
||||
# Number of links - gets a value by examining the links dictionary after it's been populated
|
||||
self.num_links = 0
|
||||
self.num_links: int = 0
|
||||
|
||||
# Number of services - gets a value when config is loaded
|
||||
self.num_services = 0
|
||||
self.num_services: int = 0
|
||||
|
||||
# Number of ports - gets a value when config is loaded
|
||||
self.num_ports = 0
|
||||
self.num_ports: int = 0
|
||||
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
# TODO: confirm type
|
||||
self.action_type: int = 0
|
||||
|
||||
# TODO fix up with TrainingConfig
|
||||
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||
@@ -184,7 +188,7 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
self._obs_space_description = None
|
||||
self._obs_space_description: List[str] = None
|
||||
"The env observation space description for transactions writing"
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -216,9 +220,13 @@ class Primaite(Env):
|
||||
_LOGGER.error("Could not save network diagram", exc_info=True)
|
||||
|
||||
# Initiate observation space
|
||||
self.observation_space: spaces.Space
|
||||
self.env_obs: np.ndarray
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
self.action_dict: Dict[int, List[int]]
|
||||
self.action_space: spaces.Space
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
_LOGGER.debug("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
@@ -246,8 +254,12 @@ class Primaite(Env):
|
||||
else:
|
||||
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
|
||||
self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=True
|
||||
)
|
||||
self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
|
||||
@property
|
||||
def actual_episode_count(self) -> int:
|
||||
@@ -256,7 +268,7 @@ class Primaite(Env):
|
||||
return self.episode_count - 1
|
||||
return self.episode_count
|
||||
|
||||
def set_as_eval(self):
|
||||
def set_as_eval(self) -> None:
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
|
||||
@@ -265,12 +277,12 @@ class Primaite(Env):
|
||||
self.total_step_count = 0
|
||||
self.episode_steps = self.training_config.num_eval_steps
|
||||
|
||||
def _write_av_reward_per_episode(self):
|
||||
def _write_av_reward_per_episode(self) -> None:
|
||||
if self.actual_episode_count > 0:
|
||||
csv_data = self.actual_episode_count, self.average_reward
|
||||
self.episode_av_reward_writer.write(csv_data)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> np.ndarray:
|
||||
"""
|
||||
AI Gym Reset function.
|
||||
|
||||
@@ -304,7 +316,7 @@ class Primaite(Env):
|
||||
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
|
||||
"""
|
||||
AI Gym Step function.
|
||||
|
||||
@@ -423,7 +435,7 @@ class Primaite(Env):
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Override parent close and close writers."""
|
||||
# Close files if last episode/step
|
||||
# if self.can_finish:
|
||||
@@ -432,18 +444,18 @@ class Primaite(Env):
|
||||
self.transaction_writer.close()
|
||||
self.episode_av_reward_writer.close()
|
||||
|
||||
def init_acl(self):
|
||||
def init_acl(self) -> None:
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
def output_link_status(self):
|
||||
def output_link_status(self) -> None:
|
||||
"""Output the link status of all links to the console."""
|
||||
for link_key, link_value in self.links.items():
|
||||
_LOGGER.debug("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.protocol_list:
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
def interpret_action_and_apply(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
@@ -462,7 +474,7 @@ class Primaite(Env):
|
||||
else:
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
def apply_actions_to_nodes(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the nodes.
|
||||
|
||||
@@ -550,7 +562,7 @@ class Primaite(Env):
|
||||
else:
|
||||
return
|
||||
|
||||
def apply_actions_to_acl(self, _action):
|
||||
def apply_actions_to_acl(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the Access Control List [TO DO].
|
||||
|
||||
@@ -630,7 +642,7 @@ class Primaite(Env):
|
||||
else:
|
||||
return
|
||||
|
||||
def apply_time_based_updates(self):
|
||||
def apply_time_based_updates(self) -> None:
|
||||
"""
|
||||
Updates anything that needs to count down and then change state.
|
||||
|
||||
@@ -686,12 +698,12 @@ class Primaite(Env):
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
def update_environent_obs(self):
|
||||
def update_environent_obs(self) -> None:
|
||||
"""Updates the observation space based on the node and link status."""
|
||||
self.obs_handler.update_obs()
|
||||
self.env_obs = self.obs_handler.current_observation
|
||||
|
||||
def load_lay_down_config(self):
|
||||
def load_lay_down_config(self) -> None:
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
for item in self.lay_down_config:
|
||||
if item["item_type"] == "NODE":
|
||||
@@ -729,7 +741,7 @@ class Primaite(Env):
|
||||
_LOGGER.info("Environment configuration loaded")
|
||||
print("Environment configuration loaded")
|
||||
|
||||
def create_node(self, item):
|
||||
def create_node(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a node from config data.
|
||||
|
||||
@@ -810,7 +822,7 @@ class Primaite(Env):
|
||||
# Add node to network (reference)
|
||||
self.network_reference.add_nodes_from([node_ref])
|
||||
|
||||
def create_link(self, item: Dict):
|
||||
def create_link(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a link from config data.
|
||||
|
||||
@@ -854,7 +866,7 @@ class Primaite(Env):
|
||||
self.services_list,
|
||||
)
|
||||
|
||||
def create_green_ier(self, item):
|
||||
def create_green_ier(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a green IER from config data.
|
||||
|
||||
@@ -895,7 +907,7 @@ class Primaite(Env):
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
def create_red_ier(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a red IER from config data.
|
||||
|
||||
@@ -925,7 +937,7 @@ class Primaite(Env):
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_green_pol(self, item):
|
||||
def create_green_pol(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a green PoL object from config data.
|
||||
|
||||
@@ -959,7 +971,7 @@ class Primaite(Env):
|
||||
pol_state,
|
||||
)
|
||||
|
||||
def create_red_pol(self, item):
|
||||
def create_red_pol(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a red PoL object from config data.
|
||||
|
||||
@@ -1000,7 +1012,7 @@ class Primaite(Env):
|
||||
pol_source_node_service_state,
|
||||
)
|
||||
|
||||
def create_acl_rule(self, item):
|
||||
def create_acl_rule(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates an ACL rule from config data.
|
||||
|
||||
@@ -1023,7 +1035,8 @@ class Primaite(Env):
|
||||
acl_rule_position,
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
# TODO: confirm typehint using runtime
|
||||
def create_services_list(self, services: Dict) -> None:
|
||||
"""
|
||||
Creates a list of services (enum) from config data.
|
||||
|
||||
@@ -1039,7 +1052,7 @@ class Primaite(Env):
|
||||
# Set the number of services
|
||||
self.num_services = len(self.services_list)
|
||||
|
||||
def create_ports_list(self, ports):
|
||||
def create_ports_list(self, ports: Dict) -> None:
|
||||
"""
|
||||
Creates a list of ports from config data.
|
||||
|
||||
@@ -1055,7 +1068,8 @@ class Primaite(Env):
|
||||
# Set the number of ports
|
||||
self.num_ports = len(self.ports_list)
|
||||
|
||||
def get_observation_info(self, observation_info):
|
||||
# TODO: this is not used anymore, write a ticket to delete it
|
||||
def get_observation_info(self, observation_info: Dict) -> None:
|
||||
"""
|
||||
Extracts observation_info.
|
||||
|
||||
@@ -1064,7 +1078,8 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
def get_action_info(self, action_info: Dict) -> None:
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
@@ -1073,7 +1088,7 @@ class Primaite(Env):
|
||||
"""
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def save_obs_config(self, obs_config: dict):
|
||||
def save_obs_config(self, obs_config: dict) -> None:
|
||||
"""
|
||||
Cache the config for the observation space.
|
||||
|
||||
@@ -1086,7 +1101,7 @@ class Primaite(Env):
|
||||
"""
|
||||
self.obs_config = obs_config
|
||||
|
||||
def reset_environment(self):
|
||||
def reset_environment(self) -> None:
|
||||
"""
|
||||
Resets environment.
|
||||
|
||||
@@ -1111,7 +1126,7 @@ class Primaite(Env):
|
||||
for ier_key, ier_value in self.red_iers.items():
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
def reset_node(self, item):
|
||||
def reset_node(self, item: Dict) -> None:
|
||||
"""
|
||||
Resets the statuses of a node.
|
||||
|
||||
@@ -1159,7 +1174,7 @@ class Primaite(Env):
|
||||
# Bad formatting
|
||||
pass
|
||||
|
||||
def create_node_action_dict(self):
|
||||
def create_node_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""
|
||||
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
|
||||
@@ -1199,7 +1214,7 @@ class Primaite(Env):
|
||||
|
||||
return actions
|
||||
|
||||
def create_acl_action_dict(self):
|
||||
def create_acl_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
@@ -1240,7 +1255,7 @@ class Primaite(Env):
|
||||
|
||||
return actions
|
||||
|
||||
def create_node_and_acl_action_dict(self):
|
||||
def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""
|
||||
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
|
||||
@@ -1257,7 +1272,7 @@ class Primaite(Env):
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
return combined_action_dict
|
||||
|
||||
def _create_random_red_agent(self):
|
||||
def _create_random_red_agent(self) -> None:
|
||||
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||
# Reset the current red iers and red node pol
|
||||
self.red_iers = {}
|
||||
|
||||
@@ -1,25 +1,31 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
from logging import Logger
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
initial_nodes: Dict[str, NodeUnion],
|
||||
final_nodes: Dict[str, NodeUnion],
|
||||
reference_nodes: Dict[str, NodeUnion],
|
||||
green_iers: Dict[str, "IER"],
|
||||
green_iers_reference: Dict[str, "IER"],
|
||||
red_iers: Dict[str, "IER"],
|
||||
step_count: int,
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
@@ -93,7 +99,9 @@ def calculate_reward_function(
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_operating_state(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
@@ -142,7 +150,12 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_os_state(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
@@ -193,7 +206,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_service_state(
|
||||
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
@@ -265,7 +280,12 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_file_system(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Network connections between nodes in the simulation."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The link class."""
|
||||
from typing import List
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.common.protocol import Protocol
|
||||
class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
|
||||
"""
|
||||
Initialise a Link within the simulated network.
|
||||
|
||||
@@ -18,17 +18,17 @@ class Link(object):
|
||||
:param _dest_node_name: The name of the destination node
|
||||
:param _protocols: The protocols to add to the link
|
||||
"""
|
||||
self.id = _id
|
||||
self.bandwidth = _bandwidth
|
||||
self.source_node_name = _source_node_name
|
||||
self.dest_node_name = _dest_node_name
|
||||
self.id: str = _id
|
||||
self.bandwidth: int = _bandwidth
|
||||
self.source_node_name: str = _source_node_name
|
||||
self.dest_node_name: str = _dest_node_name
|
||||
self.protocol_list: List[Protocol] = []
|
||||
|
||||
# Add the default protocols
|
||||
for protocol_name in _services:
|
||||
self.add_protocol(protocol_name)
|
||||
|
||||
def add_protocol(self, _protocol):
|
||||
def add_protocol(self, _protocol: str) -> None:
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
@@ -37,7 +37,7 @@ class Link(object):
|
||||
"""
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets link ID.
|
||||
|
||||
@@ -46,7 +46,7 @@ class Link(object):
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self):
|
||||
def get_source_node_name(self) -> str:
|
||||
"""
|
||||
Gets source node name.
|
||||
|
||||
@@ -55,7 +55,7 @@ class Link(object):
|
||||
"""
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self):
|
||||
def get_dest_node_name(self) -> str:
|
||||
"""
|
||||
Gets destination node name.
|
||||
|
||||
@@ -64,7 +64,7 @@ class Link(object):
|
||||
"""
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self):
|
||||
def get_bandwidth(self) -> int:
|
||||
"""
|
||||
Gets bandwidth of link.
|
||||
|
||||
@@ -73,7 +73,7 @@ class Link(object):
|
||||
"""
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self):
|
||||
def get_protocol_list(self) -> List[Protocol]:
|
||||
"""
|
||||
Gets list of protocols on this link.
|
||||
|
||||
@@ -82,7 +82,7 @@ class Link(object):
|
||||
"""
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self):
|
||||
def get_current_load(self) -> int:
|
||||
"""
|
||||
Gets current total load on this link.
|
||||
|
||||
@@ -94,7 +94,7 @@ class Link(object):
|
||||
total_load += protocol.get_load()
|
||||
return total_load
|
||||
|
||||
def add_protocol_load(self, _protocol, _load):
|
||||
def add_protocol_load(self, _protocol: str, _load: int) -> None:
|
||||
"""
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
@@ -108,7 +108,7 @@ class Link(object):
|
||||
else:
|
||||
pass
|
||||
|
||||
def clear_traffic(self):
|
||||
def clear_traffic(self) -> None:
|
||||
"""Clears all traffic on this link."""
|
||||
for protocol in self.protocol_list:
|
||||
protocol.clear_load()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
@@ -14,7 +14,7 @@ def run(
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Nodes represent network hosts in the simulation."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Final
|
||||
@@ -24,7 +24,7 @@ class ActiveNode(Node):
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an active node.
|
||||
|
||||
@@ -60,7 +60,7 @@ class ActiveNode(Node):
|
||||
return self._software_state
|
||||
|
||||
@software_state.setter
|
||||
def software_state(self, software_state: SoftwareState):
|
||||
def software_state(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
@@ -79,7 +79,7 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
@@ -99,14 +99,14 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def update_os_patching_status(self):
|
||||
def update_os_patching_status(self) -> None:
|
||||
"""Updates operating system status based on patching cycle."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self._software_state = SoftwareState.GOOD
|
||||
|
||||
def set_file_system_state(self, file_system_state: FileSystemState):
|
||||
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
@@ -133,7 +133,7 @@ class ActiveNode(Node):
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
@@ -166,12 +166,12 @@ class ActiveNode(Node):
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def start_file_system_scan(self):
|
||||
def start_file_system_scan(self) -> None:
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def update_file_system_state(self):
|
||||
def update_file_system_state(self) -> None:
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
# Deprecate both the action count (for restoring or reparing) and the scanning count
|
||||
self.file_system_action_count -= 1
|
||||
@@ -193,14 +193,14 @@ class ActiveNode(Node):
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
def update_resetting_status(self):
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the reset count & makes software and file state to GOOD."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting software and file state to GOOD."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The base Node class."""
|
||||
from typing import Final
|
||||
|
||||
@@ -17,7 +17,7 @@ class Node:
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a node.
|
||||
|
||||
@@ -38,40 +38,40 @@ class Node:
|
||||
self.booting_count: int = 0
|
||||
self.shutting_down_count: int = 0
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Returns the name of the node."""
|
||||
return self.name
|
||||
|
||||
def turn_on(self):
|
||||
def turn_on(self) -> None:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self):
|
||||
def turn_off(self) -> None:
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
self.shutting_down_count = self.config_values.node_shutdown_duration
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Sets the node state to Resetting and starts the reset count."""
|
||||
self.hardware_state = HardwareState.RESETTING
|
||||
self.resetting_count = self.config_values.node_reset_duration
|
||||
|
||||
def update_resetting_status(self):
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the resetting count."""
|
||||
self.resetting_count -= 1
|
||||
if self.resetting_count <= 0:
|
||||
self.resetting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self):
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self):
|
||||
def update_shutdown_status(self) -> None:
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
@@ -7,14 +11,14 @@ class NodeStateInstructionGreen(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_node_id,
|
||||
_node_pol_type,
|
||||
_service_name,
|
||||
_state,
|
||||
):
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_node_id: str,
|
||||
_node_pol_type: "NodePOLType",
|
||||
_service_name: str,
|
||||
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction.
|
||||
|
||||
@@ -30,11 +34,12 @@ class NodeStateInstructionGreen(object):
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
self.node_pol_type: "NodePOLType" = _node_pol_type
|
||||
self.service_name: str = _service_name # Not used when not a service instruction
|
||||
# TODO: confirm type of state
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
@@ -43,7 +48,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
@@ -52,7 +57,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
def get_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
@@ -61,7 +66,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
def get_node_pol_type(self) -> "NodePOLType":
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
@@ -70,7 +75,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
@@ -79,7 +84,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
|
||||
|
||||
|
||||
@dataclass()
|
||||
class NodeStateInstructionRed(object):
|
||||
@@ -11,18 +15,18 @@ class NodeStateInstructionRed(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_target_node_id: str,
|
||||
_pol_initiator: "NodePOLInitiator",
|
||||
_pol_type: NodePOLType,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
):
|
||||
pol_protocol: str,
|
||||
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
_pol_source_node_id: str,
|
||||
_pol_source_node_service: str,
|
||||
_pol_source_node_service_state: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction for the red agent.
|
||||
|
||||
@@ -38,19 +42,19 @@ class NodeStateInstructionRed(object):
|
||||
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.target_node_id: str = _target_node_id
|
||||
self.initiator: "NodePOLInitiator" = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
self.service_name: str = pol_protocol # Not used when not a service instruction
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
|
||||
self.source_node_id: str = _pol_source_node_id
|
||||
self.source_node_service: str = _pol_source_node_service
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
@@ -59,7 +63,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
@@ -68,7 +72,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
def get_target_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
@@ -77,7 +81,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
def get_initiator(self) -> "NodePOLInitiator":
|
||||
"""
|
||||
Gets the initiator.
|
||||
|
||||
@@ -95,7 +99,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
@@ -104,7 +108,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
@@ -113,7 +117,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
@@ -122,7 +126,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
def get_source_node_service(self) -> str:
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
@@ -131,7 +135,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
def get_source_node_service_state(self) -> str:
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
@@ -16,7 +16,7 @@ class PassiveNode(Node):
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a passive node.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
@@ -25,7 +25,7 @@ class ServiceNode(ActiveNode):
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a Service Node.
|
||||
|
||||
@@ -52,7 +52,7 @@ class ServiceNode(ActiveNode):
|
||||
)
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service):
|
||||
def add_service(self, service: Service) -> None:
|
||||
"""
|
||||
Adds a service to the node.
|
||||
|
||||
@@ -102,7 +102,7 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -131,7 +131,7 @@ class ServiceNode(ActiveNode):
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -158,7 +158,7 @@ class ServiceNode(ActiveNode):
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name):
|
||||
def get_service_state(self, protocol_name: str) -> SoftwareState:
|
||||
"""
|
||||
Gets the state of a service.
|
||||
|
||||
@@ -169,20 +169,20 @@ class ServiceNode(ActiveNode):
|
||||
if service_value:
|
||||
return service_value.software_state
|
||||
|
||||
def update_services_patching_status(self):
|
||||
def update_services_patching_status(self) -> None:
|
||||
"""Updates the patching counter for any service that are patching."""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_value.software_state == SoftwareState.PATCHING:
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self):
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Update resetting counter and set software state if it reached 0."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
def update_booting_status(self) -> None:
|
||||
"""Update booting counter and set software to good if it reached 0."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from logging import Logger
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session():
|
||||
def start_jupyter_session() -> None:
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Pattern of Life- Represents the actions of users on the network."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict, Union
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
@@ -10,11 +10,10 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE = False
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_iers(
|
||||
@@ -24,7 +23,7 @@ def apply_iers(
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
@@ -65,6 +64,8 @@ def apply_iers(
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
|
||||
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
@@ -215,9 +216,9 @@ def apply_iers(
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
node_pol: Dict[str, NodeStateInstructionGreen],
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
@@ -11,17 +11,17 @@ class IER(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_load,
|
||||
_protocol,
|
||||
_port,
|
||||
_source_node_id,
|
||||
_dest_node_id,
|
||||
_mission_criticality,
|
||||
_running=False,
|
||||
):
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_load: int,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_source_node_id: str,
|
||||
_dest_node_id: str,
|
||||
_mission_criticality: int,
|
||||
_running: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an Information Exchange Request.
|
||||
|
||||
@@ -36,18 +36,18 @@ class IER(object):
|
||||
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
:param _running: Indicates whether the IER is currently running
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.source_node_id = _source_node_id
|
||||
self.dest_node_id = _dest_node_id
|
||||
self.load = _load
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.mission_criticality = _mission_criticality
|
||||
self.running = _running
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.source_node_id: str = _source_node_id
|
||||
self.dest_node_id: str = _dest_node_id
|
||||
self.load: int = _load
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
self.mission_criticality: int = _mission_criticality
|
||||
self.running: bool = _running
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets IER ID.
|
||||
|
||||
@@ -56,7 +56,7 @@ class IER(object):
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets IER start step.
|
||||
|
||||
@@ -65,7 +65,7 @@ class IER(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets IER end step.
|
||||
|
||||
@@ -74,7 +74,7 @@ class IER(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets IER load.
|
||||
|
||||
@@ -83,7 +83,7 @@ class IER(object):
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets IER protocol.
|
||||
|
||||
@@ -92,7 +92,7 @@ class IER(object):
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets IER port.
|
||||
|
||||
@@ -101,7 +101,7 @@ class IER(object):
|
||||
"""
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
|
||||
@@ -110,7 +110,7 @@ class IER(object):
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
def get_dest_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
|
||||
@@ -119,7 +119,7 @@ class IER(object):
|
||||
"""
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
def get_is_running(self) -> bool:
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
@@ -128,7 +128,7 @@ class IER(object):
|
||||
"""
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
def set_is_running(self, _value: bool) -> None:
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
|
||||
@@ -137,7 +137,7 @@ class IER(object):
|
||||
"""
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self):
|
||||
def get_mission_criticality(self) -> int:
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
@@ -13,7 +14,9 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE = False
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(
|
||||
@@ -23,7 +26,7 @@ def apply_red_agent_iers(
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
@@ -74,6 +77,9 @@ def apply_red_agent_iers(
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
|
||||
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
|
||||
# to change according to duck typing.
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
@@ -213,7 +219,7 @@ def apply_red_agent_node_pol(
|
||||
iers: Dict[str, IER],
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
@@ -267,8 +273,7 @@ def apply_red_agent_node_pol(
|
||||
# Do nothing, service not on this node
|
||||
pass
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - misconfiguration")
|
||||
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
|
||||
|
||||
# Only apply the PoL if the checks have passed (based on the initiator type)
|
||||
if passed_checks:
|
||||
@@ -289,8 +294,7 @@ def apply_red_agent_node_pol(
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -31,7 +32,7 @@ class PrimaiteSession:
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
@@ -71,7 +72,13 @@ class PrimaiteSession:
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
def setup(self):
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
@@ -154,8 +161,8 @@ class PrimaiteSession:
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -166,8 +173,8 @@ class PrimaiteSession:
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -176,6 +183,6 @@ class PrimaiteSession:
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Utilities to prepare the user's data folders."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True):
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing=True):
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from logging import Logger
|
||||
|
||||
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""
|
||||
Handles creation of application directories and user directories.
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Record data of the system's state and agent's observations and actions."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
@@ -17,7 +21,7 @@ class Transaction(object):
|
||||
:param episode_number: The episode number
|
||||
:param step_number: The step number
|
||||
"""
|
||||
self.timestamp = datetime.now()
|
||||
self.timestamp: datetime = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier: AgentIdentifier = agent_identifier
|
||||
"The agent identifier"
|
||||
@@ -25,17 +29,17 @@ class Transaction(object):
|
||||
"The episode number"
|
||||
self.step_number: int = step_number
|
||||
"The step number"
|
||||
self.obs_space = None
|
||||
self.obs_space: "spaces.Space" = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre = None
|
||||
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post = None
|
||||
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward: float = None
|
||||
self.reward: Optional[float] = None
|
||||
"The reward value"
|
||||
self.action_space = None
|
||||
self.action_space: Optional[int] = None
|
||||
"The action space invoked by the agent"
|
||||
self.obs_space_description = None
|
||||
self.obs_space_description: Optional[List[str]] = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
@@ -68,7 +72,7 @@ class Transaction(object):
|
||||
return header, row
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
@@ -81,7 +85,7 @@ def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
return [str(action_space)]
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Utilities for PrimAITE."""
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -9,7 +10,7 @@ from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Loads a session metadata from the given directory path.
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
@@ -6,6 +7,9 @@ from primaite import getLogger
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
@@ -28,7 +32,7 @@ class SessionOutputWriter:
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Session Output Writer.
|
||||
|
||||
@@ -41,15 +45,16 @@ class SessionOutputWriter:
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
:type learning_session: bool, optional
|
||||
"""
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
self.learning_session = learning_session
|
||||
self._env: "Primaite" = env
|
||||
self.transaction_writer: bool = transaction_writer
|
||||
self.learning_session: bool = learning_session
|
||||
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
self._csv_file_path = self._env.session_path / "learning" / fn
|
||||
else:
|
||||
@@ -57,26 +62,26 @@ class SessionOutputWriter:
|
||||
|
||||
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
self._csv_file = None
|
||||
self._csv_writer = None
|
||||
self._csv_file: "TextIOWrapper" = None
|
||||
self._csv_writer: "csv._writer" = None
|
||||
|
||||
self._first_write: bool = True
|
||||
|
||||
def _init_csv_writer(self):
|
||||
def _init_csv_writer(self) -> None:
|
||||
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
|
||||
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close the cvs file."""
|
||||
if self._csv_file:
|
||||
self._csv_file.close()
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
def write(self, data: Union[Tuple, Transaction]) -> None:
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '21'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import datetime
|
||||
import json
|
||||
import shutil
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Used to tes the ACL functions."""
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Test env creation and behaviour with different observation spaces."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import pytest
|
||||
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import pytest as pytest
|
||||
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user