Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules

This commit is contained in:
Marek Wolan
2023-07-18 10:55:31 +01:00
105 changed files with 651 additions and 440 deletions

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
..
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
Welcome to PrimAITE's documentation
====================================

View File

@@ -1,4 +1,8 @@
.. _about:
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _about:
About PrimAITE
==============

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _config:
The Config Files Explained

View File

@@ -1,4 +1,8 @@
Custom Agents
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
Custom Agents
=============

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. role:: raw-html(raw)
:format: html

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _getting-started:
Getting Started

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
Glossary
=============

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
v1.2 to v2.0 Migration guide
============================

View File

@@ -1,3 +1,7 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
.. _run a primaite session:
Run a PrimAITE Session
@@ -44,7 +48,8 @@ For example, when running a session at 17:30:00 on 31st January 2023, the sessio
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
Loading a session
-------
-----------------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from setuptools import setup
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import logging
import logging.config
import sys
@@ -6,7 +6,7 @@ from bisect import bisect
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict, Final
from typing import Any, Dict, Final
import pkg_resources
import yaml
@@ -16,7 +16,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
"""An instance of `PlatformDirs` set with appname='primaite'."""
def _get_primaite_config():
def _get_primaite_config() -> Dict:
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
if not config_path.exists():
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
Credit to: https://stackoverflow.com/a/68154386
"""
def __init__(self, formats: Dict[int, str], **kwargs):
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
super().__init__()
if "fmt" in kwargs:

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Access Control List. Models firewall functionality."""

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
import logging
from typing import Final, List, Union
from typing import Dict, Final, List, Union
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
@@ -12,7 +12,7 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class AccessControlList:
"""Access Control List class."""
def __init__(self, implicit_permission, max_acl_rules):
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
"""Init."""
# Implicit ALLOW or DENY firewall spec
self.acl_implicit_permission = implicit_permission
@@ -30,7 +30,7 @@ class AccessControlList:
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
@property
def acl(self):
def acl(self) -> List[Union[ACLRule, None]]:
"""Public access method for private _acl."""
return self._acl + [self.acl_implicit_rule]
@@ -84,7 +84,9 @@ class AccessControlList:
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port, _position):
def add_rule(
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: int
) -> None:
"""
Adds a new rule.
@@ -141,12 +143,12 @@ class AccessControlList:
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
self._acl[index] = None
def remove_all_rules(self):
def remove_all_rules(self) -> None:
"""Removes all rules."""
for i in range(len(self._acl)):
self._acl[i] = None
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
"""
Produces a hash value for a rule.
@@ -164,7 +166,9 @@ class AccessControlList:
hash_value = hash(rule)
return hash_value
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
def get_relevant_rules(
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
) -> Dict[int, ACLRule]:
"""Get all ACL rules that relate to the given arguments.
:param _source_ip_address: the source IP address to check
@@ -172,7 +176,7 @@ class AccessControlList:
:param _protocol: the protocol to check
:param _port: the port to check
:return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[str, ACLRule]
:rtype: Dict[int, ACLRule]
"""
relevant_rules = {}
for rule in self.acl:

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType
@@ -6,7 +6,9 @@ from primaite.common.enums import RulePermissionType
class ACLRule:
"""Access Control List Rule class."""
def __init__(self, _permission: RulePermissionType, _source_ip, _dest_ip, _protocol, _port):
def __init__(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Initialise an ACL Rule.
@@ -16,13 +18,13 @@ class ACLRule:
:param _protocol: The rule protocol
:param _port: The rule port
"""
self.permission = _permission
self.source_ip = _source_ip
self.dest_ip = _dest_ip
self.protocol = _protocol
self.port = _port
self.permission: RulePermissionType = _permission
self.source_ip: str = _source_ip
self.dest_ip: str = _dest_ip
self.protocol: str = _protocol
self.port: str = _port
def __hash__(self):
def __hash__(self) -> int:
"""
Override the hash function.
@@ -39,7 +41,7 @@ class ACLRule:
)
)
def get_permission(self):
def get_permission(self) -> str:
"""
Gets the permission attribute.
@@ -48,7 +50,7 @@ class ACLRule:
"""
return self.permission
def get_source_ip(self):
def get_source_ip(self) -> str:
"""
Gets the source IP address attribute.
@@ -57,7 +59,7 @@ class ACLRule:
"""
return self.source_ip
def get_dest_ip(self):
def get_dest_ip(self) -> str:
"""
Gets the desintation IP address attribute.
@@ -66,7 +68,7 @@ class ACLRule:
"""
return self.dest_ip
def get_protocol(self):
def get_protocol(self) -> str:
"""
Gets the protocol attribute.
@@ -75,7 +77,7 @@ class ACLRule:
"""
return self.protocol
def get_port(self):
def get_port(self) -> str:
"""
Gets the port attribute.

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Common interface between RL agents from different libraries and PrimAITE."""

View File

@@ -1,10 +1,12 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
from uuid import uuid4
import primaite
@@ -15,7 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
@@ -50,7 +52,7 @@ class AgentSessionABC(ABC):
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise an agent session from config files, or load a previous session.
@@ -130,11 +132,11 @@ class AgentSessionABC(ABC):
return path
@property
def uuid(self):
def uuid(self) -> str:
"""The Agent Session UUID."""
return self._uuid
def _write_session_metadata_file(self):
def _write_session_metadata_file(self) -> None:
"""
Write the ``session_metadata.json`` file.
@@ -170,7 +172,7 @@ class AgentSessionABC(ABC):
json.dump(metadata_dict, file)
_LOGGER.debug("Finished writing session metadata file")
def _update_session_metadata_file(self):
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
@@ -199,7 +201,7 @@ class AgentSessionABC(ABC):
_LOGGER.debug("Finished updating session metadata file")
@abstractmethod
def _setup(self):
def _setup(self) -> None:
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
@@ -209,14 +211,14 @@ class AgentSessionABC(ABC):
self._can_evaluate = False
@abstractmethod
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
@abstractmethod
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -233,8 +235,8 @@ class AgentSessionABC(ABC):
@abstractmethod
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -247,10 +249,10 @@ class AgentSessionABC(ABC):
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def load(self, path: Union[str, Path]):
def load(self, path: Union[str, Path]) -> None:
"""Load an agent from file."""
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
@@ -274,21 +276,21 @@ class AgentSessionABC(ABC):
return self.learning_path / file_name
@abstractmethod
def save(self):
def save(self) -> None:
"""Save the agent."""
pass
@abstractmethod
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
pass
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._env.episode_av_reward_writer.close() # noqa
self._env.transaction_writer.close() # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True):
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
# self.close()
title = f"PrimAITE Session {self.timestamp_str} "
subtitle = str(self._training_config)

View File

@@ -1,7 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import time
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
import numpy as np
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
@@ -23,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise a hardcoded agent session.
@@ -36,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self):
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -47,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -65,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -102,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._env.close()
@classmethod
def load(cls, path=None):
def load(cls, path: Union[str, Path] = None) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Union
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np
@@ -32,7 +33,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[Any, Any]:
) -> Dict[str, IER]:
"""Get blocked green IERs.
:param green_iers: Green IERs to check for being
@@ -60,7 +61,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
def get_matching_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get list of ACL rules which are relevant to an IER.
:param ier: Information Exchange Request to query against the ACL list
@@ -83,7 +86,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
) -> Dict[int, ACLRule]:
"""
Get blocking ACL rules for an IER.
@@ -111,7 +114,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
) -> Dict[int, ACLRule]:
"""Get all allowing ACL rules for an IER.
:param ier: Information Exchange Request to query against the ACL list
@@ -141,7 +144,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
:param source_node_id: Source node
@@ -173,6 +176,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if protocol != "ANY":
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
# TODO: This should throw an error because protocol is a string
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
@@ -186,7 +190,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""List ALLOW rules relating to specified nodes.
:param source_node_id: Source node id
@@ -233,7 +237,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""List DENY rules relating to specified nodes.
:param source_node_id: Source node id

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
@@ -101,6 +102,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
property_action,
action_service_index,
]
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
action = transform_action_node_enum(action, action_dict)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step

View File

@@ -1,10 +1,12 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
import shutil
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Optional, Union
from typing import Any, Callable, Dict, Optional, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -18,10 +20,11 @@ from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def _env_creator(env_config):
# TODO: verify type of env_config
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
@@ -30,11 +33,12 @@ def _env_creator(env_config):
)
def _custom_log_creator(session_path: Path):
# TODO: verify type hint return type
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config):
def logger_creator(config: Dict) -> UnifiedLogger:
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
@@ -48,7 +52,7 @@ class RLlibAgent(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise the RLLib Agent training session.
@@ -73,6 +77,7 @@ class RLlibAgent(AgentSessionABC):
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config_class: Union[PPOConfig, A2CConfig]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
@@ -94,7 +99,7 @@ class RLlibAgent(AgentSessionABC):
f"{self._training_config.deep_learning_framework}"
)
def _update_session_metadata_file(self):
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
@@ -122,7 +127,7 @@ class RLlibAgent(AgentSessionABC):
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self):
def _setup(self) -> None:
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
@@ -148,7 +153,7 @@ class RLlibAgent(AgentSessionABC):
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
save_checkpoint = False
@@ -159,8 +164,8 @@ class RLlibAgent(AgentSessionABC):
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -180,8 +185,8 @@ class RLlibAgent(AgentSessionABC):
def evaluate(
self,
**kwargs,
):
**kwargs: None,
) -> None:
"""
Evaluate the agent.
@@ -189,7 +194,7 @@ class RLlibAgent(AgentSessionABC):
"""
raise NotImplementedError
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError
@classmethod
@@ -197,7 +202,7 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self, overwrite_existing: bool = True):
def save(self, overwrite_existing: bool = True) -> None:
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
@@ -217,6 +222,6 @@ class RLlibAgent(AgentSessionABC):
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,8 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
from logging import Logger
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
@@ -13,7 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
class SB3Agent(AgentSessionABC):
@@ -24,7 +26,7 @@ class SB3Agent(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise the SB3 Agent training session.
@@ -42,6 +44,7 @@ class SB3Agent(AgentSessionABC):
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_class: Union[PPO, A2C]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
@@ -65,7 +68,7 @@ class SB3Agent(AgentSessionABC):
self._setup()
def _setup(self):
def _setup(self) -> None:
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
@@ -112,7 +115,7 @@ class SB3Agent(AgentSessionABC):
super()._setup()
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
save_checkpoint = False
@@ -123,13 +126,13 @@ class SB3Agent(AgentSessionABC):
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -152,8 +155,8 @@ class SB3Agent(AgentSessionABC):
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -182,10 +185,10 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
def save(self):
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,3 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
@@ -9,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC):
Get a completely random action from the action space.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
return self._env.action_space.sample()
@@ -20,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC):
All action spaces setup so dummy action is always 0 regardless of action type used.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
return 0
@@ -31,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
@@ -46,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np
@@ -34,11 +35,11 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
return new_action
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Provides a CLI using Typer as an entry point."""
import logging
import os
@@ -19,7 +19,7 @@ app = typer.Typer()
@app.command()
def build_dirs():
def build_dirs() -> None:
"""Build the PrimAITE app directories."""
from primaite.setup import setup_app_dirs
@@ -27,7 +27,7 @@ def build_dirs():
@app.command()
def reset_notebooks(overwrite: bool = True):
def reset_notebooks(overwrite: bool = True) -> None:
"""
Force a reset of the demo notebooks in the users notebooks directory.
@@ -39,7 +39,7 @@ def reset_notebooks(overwrite: bool = True):
@app.command()
def logs(last_n: Annotated[int, typer.Option("-n")]):
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
"""
Print the PrimAITE log file.
@@ -60,7 +60,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
@app.command()
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
"""
View or set the PrimAITE Log Level.
@@ -88,7 +88,7 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
@app.command()
def notebooks():
def notebooks() -> None:
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
@@ -96,7 +96,7 @@ def notebooks():
@app.command()
def version():
def version() -> None:
"""Get the installed PrimAITE version number."""
import primaite
@@ -104,7 +104,7 @@ def version():
@app.command()
def clean_up():
def clean_up() -> None:
"""Cleans up left over files from previous version installations."""
from primaite.setup import old_installation_clean_up
@@ -112,7 +112,7 @@ def clean_up():
@app.command()
def setup(overwrite_existing: bool = True):
def setup(overwrite_existing: bool = True) -> None:
"""
Perform the PrimAITE first-time setup.
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None):
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
"""
Run a PrimAITE session.
@@ -185,7 +185,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
"""
View or set the plotly template for Session plots.

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Objects which are shared between many PrimAITE modules."""

View File

@@ -1,8 +1,8 @@
from typing import Type, Union
from typing import Union
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode]
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""
from enum import Enum, IntEnum
@@ -148,6 +148,7 @@ class ActionType(Enum):
ANY = 2
# TODO: this is not used anymore, write a ticket to delete it.
class ObservationType(Enum):
"""Observation type enumeration."""

View File

@@ -1,21 +1,21 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The protocol class."""
class Protocol(object):
"""Protocol class."""
def __init__(self, _name):
def __init__(self, _name: str) -> None:
"""
Initialise a protocol.
:param _name: The name of the protocol
:type _name: str
"""
self.name = _name
self.load = 0 # bps
self.name: str = _name
self.load: int = 0 # bps
def get_name(self):
def get_name(self) -> str:
"""
Gets the protocol name.
@@ -24,7 +24,7 @@ class Protocol(object):
"""
return self.name
def get_load(self):
def get_load(self) -> int:
"""
Gets the protocol load.
@@ -33,7 +33,7 @@ class Protocol(object):
"""
return self.load
def add_load(self, _load):
def add_load(self, _load: int) -> None:
"""
Adds load to the protocol.
@@ -42,6 +42,6 @@ class Protocol(object):
"""
self.load += _load
def clear_load(self):
def clear_load(self) -> None:
"""Clears the load on this protocol."""
self.load = 0

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Service class."""
from primaite.common.enums import SoftwareState
@@ -7,7 +7,7 @@ from primaite.common.enums import SoftwareState
class Service(object):
"""Service class."""
def __init__(self, name: str, port: str, software_state: SoftwareState):
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
"""
Initialise a service.
@@ -15,12 +15,12 @@ class Service(object):
:param port: The service port.
:param software_state: The service SoftwareState.
"""
self.name = name
self.port = port
self.software_state = software_state
self.patching_count = 0
self.name: str = name
self.port: str = port
self.software_state: SoftwareState = software_state
self.patching_count: int = 0
def reduce_patching_count(self):
def reduce_patching_count(self) -> None:
"""Reduces the patching count for the service."""
self.patching_count -= 1
if self.patching_count <= 0:

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Configuration parameters for running experiments."""

View File

@@ -1,4 +1,5 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Union
@@ -6,7 +7,7 @@ import yaml
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"

View File

@@ -1,7 +1,8 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
@@ -19,7 +20,7 @@ from primaite.common.enums import (
SessionType,
)
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
@@ -86,7 +87,7 @@ class TrainingConfig:
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: str = False
load_agent: bool = False
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
@@ -198,7 +199,7 @@ class TrainingConfig:
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
@@ -216,12 +217,14 @@ class TrainingConfig:
"implicit_acl_rule": RulePermissionType,
}
# convert the string representation of enums into the actual enum values themselves?
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True):
def to_dict(self, json_serializable: bool = True) -> Dict:
"""
Serialise the ``TrainingConfig`` as dict.
@@ -341,7 +344,7 @@ def convert_legacy_training_config_dict(
return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> str:
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
"""
Maps legacy training config keys to the new format keys.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utility to generate plots of sessions metrics after PrimAITE."""
from enum import Enum

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Dict, Optional, Union

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""

View File

@@ -1,6 +1,8 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from logging import Logger
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
@@ -17,14 +19,15 @@ from primaite.nodes.service_node import ServiceNode
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
_LOGGER: Logger = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise observation component.
@@ -39,7 +42,7 @@ class AbstractObservationComponent(ABC):
return NotImplemented
@abstractmethod
def update(self):
def update(self) -> None:
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@@ -74,7 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
_MAX_VAL: int = 1_000_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeLinkTable observation space component.
@@ -101,7 +104,7 @@ class NodeLinkTable(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -148,7 +151,7 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_index += 1
item_index += 1
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
@@ -211,7 +214,7 @@ class NodeStatuses(AbstractObservationComponent):
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeStatuses observation component.
@@ -237,7 +240,7 @@ class NodeStatuses(AbstractObservationComponent):
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -268,7 +271,7 @@ class NodeStatuses(AbstractObservationComponent):
)
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
@@ -318,7 +321,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
) -> None:
"""
Initialise a LinkTrafficLevels observation component.
@@ -360,7 +363,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -386,7 +389,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
@@ -470,7 +473,7 @@ class AccessControlList(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.AccessControlList`
@@ -550,7 +553,7 @@ class AccessControlList(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for acl_rule in self.env.acl.acl:
@@ -593,7 +596,7 @@ class ObservationsHandler:
"ACCESS_CONTROL_LIST": AccessControlList,
}
def __init__(self):
def __init__(self) -> None:
"""Initialise the observation handler."""
self.registered_obs_components: List[AbstractObservationComponent] = []
@@ -606,7 +609,7 @@ class ObservationsHandler:
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
def update_obs(self):
def update_obs(self) -> None:
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
@@ -619,7 +622,7 @@ class ObservationsHandler:
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
def register(self, obs_component: AbstractObservationComponent) -> None:
"""
Add a component for this handler to track.
@@ -629,7 +632,7 @@ class ObservationsHandler:
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
def deregister(self, obs_component: AbstractObservationComponent) -> None:
"""
Remove a component from this handler.
@@ -640,7 +643,7 @@ class ObservationsHandler:
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
def update_space(self) -> None:
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
@@ -657,7 +660,7 @@ class ObservationsHandler:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self):
def space(self) -> spaces.Space:
"""Observation space, return the flattened version if flatten is True."""
if len(self.registered_obs_components) > 1:
return self._flat_space
@@ -665,7 +668,7 @@ class ObservationsHandler:
return self._space
@property
def current_observation(self):
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""Current observation, return the flattened version if flatten is True."""
if len(self.registered_obs_components) > 1:
return self._flat_observation
@@ -673,7 +676,7 @@ class ObservationsHandler:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
"""
Parse a config dictinary, return a new observation handler populated with new observation component objects.
@@ -716,7 +719,7 @@ class ObservationsHandler:
handler.update_obs()
return handler
def describe_structure(self):
def describe_structure(self) -> List[str]:
"""
Create a list of names for the features of the obs space.

View File

@@ -1,11 +1,12 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import logging
import uuid as uuid
from logging import Logger
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Dict, Final, Tuple, Union
from typing import Any, Dict, Final, List, Tuple, Union
import networkx as nx
import numpy as np
@@ -20,6 +21,7 @@ from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
FileSystemState,
HardwareState,
NodePOLInitiator,
@@ -48,7 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
class Primaite(Env):
@@ -66,7 +68,7 @@ class Primaite(Env):
lay_down_config_path: Union[str, Path],
session_path: Path,
timestamp_str: str,
):
) -> None:
"""
The Primaite constructor.
@@ -77,13 +79,14 @@ class Primaite(Env):
"""
self.session_path: Final[Path] = session_path
self.timestamp_str: Final[str] = timestamp_str
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self._training_config_path: Union[str, Path] = training_config_path
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
self.training_config: TrainingConfig = training_config.load(training_config_path)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps: int
if self.training_config.session_type == SessionType.TRAIN:
self.episode_steps = self.training_config.num_train_steps
elif self.training_config.session_type == SessionType.EVAL:
@@ -94,7 +97,7 @@ class Primaite(Env):
super(Primaite, self).__init__()
# The agent in use
self.agent_identifier = self.training_config.agent_identifier
self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier
# Create a dictionary to hold all the nodes
self.nodes: Dict[str, NodeUnion] = {}
@@ -113,42 +116,42 @@ class Primaite(Env):
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
self.node_pol = {}
self.node_pol: Dict[str, NodeStateInstructionGreen] = {}
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
self.red_iers = {}
self.red_iers: Dict[str, IER] = {}
# Create a dictionary to hold all the red agent node PoLs (this will come from an external source)
self.red_node_pol = {}
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
# Create the Access Control List
self.acl = AccessControlList(
self.acl: AccessControlList = AccessControlList(
self.training_config.implicit_acl_rule,
self.training_config.max_number_acl_rules,
)
# Sets limit for number of ACL rules in environment
self.max_number_acl_rules = self.training_config.max_number_acl_rules
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
# Create a list of services (enums)
self.services_list = []
self.services_list: List[str] = []
# Create a list of ports
self.ports_list = []
self.ports_list: List[str] = []
# Create graph (network)
self.network = nx.MultiGraph()
self.network: nx.Graph = nx.MultiGraph()
# Create a graph (network) reference
self.network_reference = nx.MultiGraph()
self.network_reference: nx.Graph = nx.MultiGraph()
# Create step count
self.step_count = 0
self.step_count: int = 0
self.total_step_count: int = 0
"""The total number of time steps completed."""
# Create step info dictionary
self.step_info = {}
self.step_info: Dict[Any] = {}
# Total reward
self.total_reward: float = 0
@@ -157,22 +160,23 @@ class Primaite(Env):
self.average_reward: float = 0
# Episode count
self.episode_count = 0
self.episode_count: int = 0
# Number of nodes - gets a value by examining the nodes dictionary after it's been populated
self.num_nodes = 0
self.num_nodes: int = 0
# Number of links - gets a value by examining the links dictionary after it's been populated
self.num_links = 0
self.num_links: int = 0
# Number of services - gets a value when config is loaded
self.num_services = 0
self.num_services: int = 0
# Number of ports - gets a value when config is loaded
self.num_ports = 0
self.num_ports: int = 0
# The action type
self.action_type = 0
# TODO: confirm type
self.action_type: int = 0
# TODO fix up with TrainingConfig
# stores the observation config from the yaml, default is NODE_LINK_TABLE
@@ -184,7 +188,7 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
self._obs_space_description = None
self._obs_space_description: List[str] = None
"The env observation space description for transactions writing"
# Open the config file and build the environment laydown
@@ -216,9 +220,13 @@ class Primaite(Env):
_LOGGER.error("Could not save network diagram", exc_info=True)
# Initiate observation space
self.observation_space: spaces.Space
self.env_obs: np.ndarray
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
self.action_dict: Dict[int, List[int]]
self.action_space: spaces.Space
if self.training_config.action_type == ActionType.NODE:
_LOGGER.debug("Action space type NODE selected")
# Terms (for node action space):
@@ -246,8 +254,12 @@ class Primaite(Env):
else:
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
self, transaction_writer=False, learning_session=True
)
self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
self, transaction_writer=True, learning_session=True
)
@property
def actual_episode_count(self) -> int:
@@ -256,7 +268,7 @@ class Primaite(Env):
return self.episode_count - 1
return self.episode_count
def set_as_eval(self):
def set_as_eval(self) -> None:
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
@@ -265,12 +277,12 @@ class Primaite(Env):
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
def _write_av_reward_per_episode(self):
def _write_av_reward_per_episode(self) -> None:
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
def reset(self):
def reset(self) -> np.ndarray:
"""
AI Gym Reset function.
@@ -304,7 +316,7 @@ class Primaite(Env):
return self.env_obs
def step(self, action):
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
"""
AI Gym Step function.
@@ -423,7 +435,7 @@ class Primaite(Env):
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
def close(self) -> None:
"""Override parent close and close writers."""
# Close files if last episode/step
# if self.can_finish:
@@ -432,18 +444,18 @@ class Primaite(Env):
self.transaction_writer.close()
self.episode_av_reward_writer.close()
def init_acl(self):
def init_acl(self) -> None:
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
def output_link_status(self):
def output_link_status(self) -> None:
"""Output the link status of all links to the console."""
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
def interpret_action_and_apply(self, _action: int) -> None:
"""
Applies agent actions to the nodes and Access Control List.
@@ -462,7 +474,7 @@ class Primaite(Env):
else:
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
def apply_actions_to_nodes(self, _action: int) -> None:
"""
Applies agent actions to the nodes.
@@ -550,7 +562,7 @@ class Primaite(Env):
else:
return
def apply_actions_to_acl(self, _action):
def apply_actions_to_acl(self, _action: int) -> None:
"""
Applies agent actions to the Access Control List [TO DO].
@@ -630,7 +642,7 @@ class Primaite(Env):
else:
return
def apply_time_based_updates(self):
def apply_time_based_updates(self) -> None:
"""
Updates anything that needs to count down and then change state.
@@ -686,12 +698,12 @@ class Primaite(Env):
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
def update_environent_obs(self) -> None:
"""Updates the observation space based on the node and link status."""
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
def load_lay_down_config(self):
def load_lay_down_config(self) -> None:
"""Loads config data in order to build the environment configuration."""
for item in self.lay_down_config:
if item["item_type"] == "NODE":
@@ -729,7 +741,7 @@ class Primaite(Env):
_LOGGER.info("Environment configuration loaded")
print("Environment configuration loaded")
def create_node(self, item):
def create_node(self, item: Dict) -> None:
"""
Creates a node from config data.
@@ -810,7 +822,7 @@ class Primaite(Env):
# Add node to network (reference)
self.network_reference.add_nodes_from([node_ref])
def create_link(self, item: Dict):
def create_link(self, item: Dict) -> None:
"""
Creates a link from config data.
@@ -854,7 +866,7 @@ class Primaite(Env):
self.services_list,
)
def create_green_ier(self, item):
def create_green_ier(self, item: Dict) -> None:
"""
Creates a green IER from config data.
@@ -895,7 +907,7 @@ class Primaite(Env):
ier_mission_criticality,
)
def create_red_ier(self, item):
def create_red_ier(self, item: Dict) -> None:
"""
Creates a red IER from config data.
@@ -925,7 +937,7 @@ class Primaite(Env):
ier_mission_criticality,
)
def create_green_pol(self, item):
def create_green_pol(self, item: Dict) -> None:
"""
Creates a green PoL object from config data.
@@ -959,7 +971,7 @@ class Primaite(Env):
pol_state,
)
def create_red_pol(self, item):
def create_red_pol(self, item: Dict) -> None:
"""
Creates a red PoL object from config data.
@@ -1000,7 +1012,7 @@ class Primaite(Env):
pol_source_node_service_state,
)
def create_acl_rule(self, item):
def create_acl_rule(self, item: Dict) -> None:
"""
Creates an ACL rule from config data.
@@ -1023,7 +1035,8 @@ class Primaite(Env):
acl_rule_position,
)
def create_services_list(self, services):
# TODO: confirm typehint using runtime
def create_services_list(self, services: Dict) -> None:
"""
Creates a list of services (enum) from config data.
@@ -1039,7 +1052,7 @@ class Primaite(Env):
# Set the number of services
self.num_services = len(self.services_list)
def create_ports_list(self, ports):
def create_ports_list(self, ports: Dict) -> None:
"""
Creates a list of ports from config data.
@@ -1055,7 +1068,8 @@ class Primaite(Env):
# Set the number of ports
self.num_ports = len(self.ports_list)
def get_observation_info(self, observation_info):
# TODO: this is not used anymore, write a ticket to delete it
def get_observation_info(self, observation_info: Dict) -> None:
"""
Extracts observation_info.
@@ -1064,7 +1078,8 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
# TODO: this is not used anymore, write a ticket to delete it.
def get_action_info(self, action_info: Dict) -> None:
"""
Extracts action_info.
@@ -1073,7 +1088,7 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def save_obs_config(self, obs_config: dict):
def save_obs_config(self, obs_config: dict) -> None:
"""
Cache the config for the observation space.
@@ -1086,7 +1101,7 @@ class Primaite(Env):
"""
self.obs_config = obs_config
def reset_environment(self):
def reset_environment(self) -> None:
"""
Resets environment.
@@ -1111,7 +1126,7 @@ class Primaite(Env):
for ier_key, ier_value in self.red_iers.items():
ier_value.set_is_running(False)
def reset_node(self, item):
def reset_node(self, item: Dict) -> None:
"""
Resets the statuses of a node.
@@ -1159,7 +1174,7 @@ class Primaite(Env):
# Bad formatting
pass
def create_node_action_dict(self):
def create_node_action_dict(self) -> Dict[int, List[int]]:
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
@@ -1199,7 +1214,7 @@ class Primaite(Env):
return actions
def create_acl_action_dict(self):
def create_acl_action_dict(self) -> Dict[int, List[int]]:
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
@@ -1240,7 +1255,7 @@ class Primaite(Env):
return actions
def create_node_and_acl_action_dict(self):
def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]:
"""
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
@@ -1257,7 +1272,7 @@ class Primaite(Env):
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict
def _create_random_red_agent(self):
def _create_random_red_agent(self) -> None:
"""Decide on random red agent for the episode to be called in env.reset()."""
# Reset the current red iers and red node pol
self.red_iers = {}

View File

@@ -1,25 +1,31 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict
from logging import Logger
from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.config.training_config import TrainingConfig
from primaite.pol.ier import IER
_LOGGER: Logger = getLogger(__name__)
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
green_iers_reference,
red_iers,
step_count,
config_values,
initial_nodes: Dict[str, NodeUnion],
final_nodes: Dict[str, NodeUnion],
reference_nodes: Dict[str, NodeUnion],
green_iers: Dict[str, "IER"],
green_iers_reference: Dict[str, "IER"],
red_iers: Dict[str, "IER"],
step_count: int,
config_values: "TrainingConfig",
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
@@ -93,7 +99,9 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_operating_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the hardware state of a node.
@@ -142,7 +150,12 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_os_state(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the Software State of a node.
@@ -193,7 +206,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_service_state(
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the service state(s) of a node.
@@ -265,7 +280,12 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
def score_node_file_system(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the file system state of a node.

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Network connections between nodes in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The link class."""
from typing import List
@@ -8,7 +8,7 @@ from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
"""
Initialise a Link within the simulated network.
@@ -18,17 +18,17 @@ class Link(object):
:param _dest_node_name: The name of the destination node
:param _protocols: The protocols to add to the link
"""
self.id = _id
self.bandwidth = _bandwidth
self.source_node_name = _source_node_name
self.dest_node_name = _dest_node_name
self.id: str = _id
self.bandwidth: int = _bandwidth
self.source_node_name: str = _source_node_name
self.dest_node_name: str = _dest_node_name
self.protocol_list: List[Protocol] = []
# Add the default protocols
for protocol_name in _services:
self.add_protocol(protocol_name)
def add_protocol(self, _protocol):
def add_protocol(self, _protocol: str) -> None:
"""
Adds a new protocol to the list of protocols on this link.
@@ -37,7 +37,7 @@ class Link(object):
"""
self.protocol_list.append(Protocol(_protocol))
def get_id(self):
def get_id(self) -> str:
"""
Gets link ID.
@@ -46,7 +46,7 @@ class Link(object):
"""
return self.id
def get_source_node_name(self):
def get_source_node_name(self) -> str:
"""
Gets source node name.
@@ -55,7 +55,7 @@ class Link(object):
"""
return self.source_node_name
def get_dest_node_name(self):
def get_dest_node_name(self) -> str:
"""
Gets destination node name.
@@ -64,7 +64,7 @@ class Link(object):
"""
return self.dest_node_name
def get_bandwidth(self):
def get_bandwidth(self) -> int:
"""
Gets bandwidth of link.
@@ -73,7 +73,7 @@ class Link(object):
"""
return self.bandwidth
def get_protocol_list(self):
def get_protocol_list(self) -> List[Protocol]:
"""
Gets list of protocols on this link.
@@ -82,7 +82,7 @@ class Link(object):
"""
return self.protocol_list
def get_current_load(self):
def get_current_load(self) -> int:
"""
Gets current total load on this link.
@@ -94,7 +94,7 @@ class Link(object):
total_load += protocol.get_load()
return total_load
def add_protocol_load(self, _protocol, _load):
def add_protocol_load(self, _protocol: str, _load: int) -> None:
"""
Adds a loading to a protocol on this link.
@@ -108,7 +108,7 @@ class Link(object):
else:
pass
def clear_traffic(self):
def clear_traffic(self) -> None:
"""Clears all traffic on this link."""
for protocol in self.protocol_list:
protocol.clear_load()

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
@@ -14,7 +14,7 @@ def run(
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Run the PrimAITE Session.

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Nodes represent network hosts in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final
@@ -24,7 +24,7 @@ class ActiveNode(Node):
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise an active node.
@@ -60,7 +60,7 @@ class ActiveNode(Node):
return self._software_state
@software_state.setter
def software_state(self, software_state: SoftwareState):
def software_state(self, software_state: SoftwareState) -> None:
"""
Get the software_state.
@@ -79,7 +79,7 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
"""
Sets Software State if the node is not compromised.
@@ -99,14 +99,14 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def update_os_patching_status(self):
def update_os_patching_status(self) -> None:
"""Updates operating system status based on patching cycle."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self._software_state = SoftwareState.GOOD
def set_file_system_state(self, file_system_state: FileSystemState):
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed).
@@ -133,7 +133,7 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed) if not in a compromised state.
@@ -166,12 +166,12 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def start_file_system_scan(self):
def start_file_system_scan(self) -> None:
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self):
def update_file_system_state(self) -> None:
"""Updates file system status based on scanning/restore/repair cycle."""
# Deprecate both the action count (for restoring or reparing) and the scanning count
self.file_system_action_count -= 1
@@ -193,14 +193,14 @@ class ActiveNode(Node):
self.file_system_scanning = False
self.file_system_scanning_count = 0
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Updates the reset count & makes software and file state to GOOD."""
super().update_resetting_status()
if self.resetting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Updates the booting software and file state to GOOD."""
super().update_booting_status()
if self.booting_count <= 0:

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The base Node class."""
from typing import Final
@@ -17,7 +17,7 @@ class Node:
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a node.
@@ -38,40 +38,40 @@ class Node:
self.booting_count: int = 0
self.shutting_down_count: int = 0
def __repr__(self):
def __repr__(self) -> str:
"""Returns the name of the node."""
return self.name
def turn_on(self):
def turn_on(self) -> None:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
def turn_off(self) -> None:
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
self.shutting_down_count = self.config_values.node_shutdown_duration
def reset(self):
def reset(self) -> None:
"""Sets the node state to Resetting and starts the reset count."""
self.hardware_state = HardwareState.RESETTING
self.resetting_count = self.config_values.node_reset_duration
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Updates the resetting count."""
self.resetting_count -= 1
if self.resetting_count <= 0:
self.resetting_count = 0
self.hardware_state = HardwareState.ON
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
def update_shutdown_status(self) -> None:
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:

View File

@@ -1,5 +1,9 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
class NodeStateInstructionGreen(object):
@@ -7,14 +11,14 @@ class NodeStateInstructionGreen(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_node_id,
_node_pol_type,
_service_name,
_state,
):
_id: str,
_start_step: int,
_end_step: int,
_node_id: str,
_node_pol_type: "NodePOLType",
_service_name: str,
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
) -> None:
"""
Initialise the Node State Instruction.
@@ -30,11 +34,12 @@ class NodeStateInstructionGreen(object):
self.start_step = _start_step
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type = _node_pol_type
self.service_name = _service_name # Not used when not a service instruction
self.state = _state
self.node_pol_type: "NodePOLType" = _node_pol_type
self.service_name: str = _service_name # Not used when not a service instruction
# TODO: confirm type of state
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets the start step.
@@ -43,7 +48,7 @@ class NodeStateInstructionGreen(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets the end step.
@@ -52,7 +57,7 @@ class NodeStateInstructionGreen(object):
"""
return self.end_step
def get_node_id(self):
def get_node_id(self) -> str:
"""
Gets the node ID.
@@ -61,7 +66,7 @@ class NodeStateInstructionGreen(object):
"""
return self.node_id
def get_node_pol_type(self):
def get_node_pol_type(self) -> "NodePOLType":
"""
Gets the node pattern of life type (enum).
@@ -70,7 +75,7 @@ class NodeStateInstructionGreen(object):
"""
return self.node_pol_type
def get_service_name(self):
def get_service_name(self) -> str:
"""
Gets the service name.
@@ -79,7 +84,7 @@ class NodeStateInstructionGreen(object):
"""
return self.service_name
def get_state(self):
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).

View File

@@ -1,9 +1,13 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from primaite.common.enums import NodePOLType
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
@dataclass()
class NodeStateInstructionRed(object):
@@ -11,18 +15,18 @@ class NodeStateInstructionRed(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_target_node_id,
_pol_initiator,
_id: str,
_start_step: int,
_end_step: int,
_target_node_id: str,
_pol_initiator: "NodePOLInitiator",
_pol_type: NodePOLType,
pol_protocol,
_pol_state,
_pol_source_node_id,
_pol_source_node_service,
_pol_source_node_service_state,
):
pol_protocol: str,
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
_pol_source_node_id: str,
_pol_source_node_service: str,
_pol_source_node_service_state: str,
) -> None:
"""
Initialise the Node State Instruction for the red agent.
@@ -38,19 +42,19 @@ class NodeStateInstructionRed(object):
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.target_node_id = _target_node_id
self.initiator = _pol_initiator
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.target_node_id: str = _target_node_id
self.initiator: "NodePOLInitiator" = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name = pol_protocol # Not used when not a service instruction
self.state = _pol_state
self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service
self.service_name: str = pol_protocol # Not used when not a service instruction
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
self.source_node_id: str = _pol_source_node_id
self.source_node_service: str = _pol_source_node_service
self.source_node_service_state = _pol_source_node_service_state
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets the start step.
@@ -59,7 +63,7 @@ class NodeStateInstructionRed(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets the end step.
@@ -68,7 +72,7 @@ class NodeStateInstructionRed(object):
"""
return self.end_step
def get_target_node_id(self):
def get_target_node_id(self) -> str:
"""
Gets the node ID.
@@ -77,7 +81,7 @@ class NodeStateInstructionRed(object):
"""
return self.target_node_id
def get_initiator(self):
def get_initiator(self) -> "NodePOLInitiator":
"""
Gets the initiator.
@@ -95,7 +99,7 @@ class NodeStateInstructionRed(object):
"""
return self.pol_type
def get_service_name(self):
def get_service_name(self) -> str:
"""
Gets the service name.
@@ -104,7 +108,7 @@ class NodeStateInstructionRed(object):
"""
return self.service_name
def get_state(self):
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
@@ -113,7 +117,7 @@ class NodeStateInstructionRed(object):
"""
return self.state
def get_source_node_id(self):
def get_source_node_id(self) -> str:
"""
Gets the source node id (used for initiator type SERVICE).
@@ -122,7 +126,7 @@ class NodeStateInstructionRed(object):
"""
return self.source_node_id
def get_source_node_service(self):
def get_source_node_service(self) -> str:
"""
Gets the source node service (used for initiator type SERVICE).
@@ -131,7 +135,7 @@ class NodeStateInstructionRed(object):
"""
return self.source_node_service
def get_source_node_service_state(self):
def get_source_node_service_state(self) -> str:
"""
Gets the source node service state (used for initiator type SERVICE).

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
@@ -16,7 +16,7 @@ class PassiveNode(Node):
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a passive node.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final
@@ -25,7 +25,7 @@ class ServiceNode(ActiveNode):
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a Service Node.
@@ -52,7 +52,7 @@ class ServiceNode(ActiveNode):
)
self.services: Dict[str, Service] = {}
def add_service(self, service: Service):
def add_service(self, service: Service) -> None:
"""
Adds a service to the node.
@@ -102,7 +102,7 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
@@ -131,7 +131,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
@@ -158,7 +158,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def get_service_state(self, protocol_name):
def get_service_state(self, protocol_name: str) -> SoftwareState:
"""
Gets the state of a service.
@@ -169,20 +169,20 @@ class ServiceNode(ActiveNode):
if service_value:
return service_value.software_state
def update_services_patching_status(self):
def update_services_patching_status(self) -> None:
"""Updates the patching counter for any service that are patching."""
for service_key, service_value in self.services.items():
if service_value.software_state == SoftwareState.PATCHING:
service_value.reduce_patching_count()
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:

View File

@@ -1,16 +1,18 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import importlib.util
import os
import subprocess
import sys
from logging import Logger
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def start_jupyter_session():
def start_jupyter_session() -> None:
"""
Starts a new Jupyter notebook session in the app notebooks directory.

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Pattern of Life- Represents the actions of users on the network."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict, Union
from typing import Dict
from networkx import MultiGraph, shortest_path
@@ -10,11 +10,10 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE = False
_VERBOSE: bool = False
def apply_iers(
@@ -24,7 +23,7 @@ def apply_iers(
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
):
) -> None:
"""
Applies IERs to the links (link pattern of life).
@@ -65,6 +64,8 @@ def apply_iers(
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if (
@@ -215,9 +216,9 @@ def apply_iers(
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
node_pol: Dict[str, NodeStateInstructionGreen],
step: int,
):
) -> None:
"""
Applies node pattern of life.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""
Information Exchange Requirements for APE.
@@ -11,17 +11,17 @@ class IER(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_load,
_protocol,
_port,
_source_node_id,
_dest_node_id,
_mission_criticality,
_running=False,
):
_id: str,
_start_step: int,
_end_step: int,
_load: int,
_protocol: str,
_port: str,
_source_node_id: str,
_dest_node_id: str,
_mission_criticality: int,
_running: bool = False,
) -> None:
"""
Initialise an Information Exchange Request.
@@ -36,18 +36,18 @@ class IER(object):
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.source_node_id = _source_node_id
self.dest_node_id = _dest_node_id
self.load = _load
self.protocol = _protocol
self.port = _port
self.mission_criticality = _mission_criticality
self.running = _running
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.source_node_id: str = _source_node_id
self.dest_node_id: str = _dest_node_id
self.load: int = _load
self.protocol: str = _protocol
self.port: str = _port
self.mission_criticality: int = _mission_criticality
self.running: bool = _running
def get_id(self):
def get_id(self) -> str:
"""
Gets IER ID.
@@ -56,7 +56,7 @@ class IER(object):
"""
return self.id
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets IER start step.
@@ -65,7 +65,7 @@ class IER(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets IER end step.
@@ -74,7 +74,7 @@ class IER(object):
"""
return self.end_step
def get_load(self):
def get_load(self) -> int:
"""
Gets IER load.
@@ -83,7 +83,7 @@ class IER(object):
"""
return self.load
def get_protocol(self):
def get_protocol(self) -> str:
"""
Gets IER protocol.
@@ -92,7 +92,7 @@ class IER(object):
"""
return self.protocol
def get_port(self):
def get_port(self) -> str:
"""
Gets IER port.
@@ -101,7 +101,7 @@ class IER(object):
"""
return self.port
def get_source_node_id(self):
def get_source_node_id(self) -> str:
"""
Gets IER source node ID.
@@ -110,7 +110,7 @@ class IER(object):
"""
return self.source_node_id
def get_dest_node_id(self):
def get_dest_node_id(self) -> str:
"""
Gets IER destination node ID.
@@ -119,7 +119,7 @@ class IER(object):
"""
return self.dest_node_id
def get_is_running(self):
def get_is_running(self) -> bool:
"""
Informs whether the IER is currently running.
@@ -128,7 +128,7 @@ class IER(object):
"""
return self.running
def set_is_running(self, _value):
def set_is_running(self, _value: bool) -> None:
"""
Sets the running state of the IER.
@@ -137,7 +137,7 @@ class IER(object):
"""
self.running = _value
def get_mission_criticality(self):
def get_mission_criticality(self) -> int:
"""
Gets the IER mission criticality (used in the reward function).

View File

@@ -1,9 +1,10 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
@@ -13,7 +14,9 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE = False
_LOGGER = getLogger(__name__)
_VERBOSE: bool = False
def apply_red_agent_iers(
@@ -23,7 +26,7 @@ def apply_red_agent_iers(
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
):
) -> None:
"""
Applies IERs to the links (link POL) resulting from red agent attack.
@@ -74,6 +77,9 @@ def apply_red_agent_iers(
pass
else:
# It's not a switch or an actuator (so active node)
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
# to change according to duck typing.
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state
@@ -213,7 +219,7 @@ def apply_red_agent_node_pol(
iers: Dict[str, IER],
node_pol: Dict[str, NodeStateInstructionRed],
step: int,
):
) -> None:
"""
Applies node pattern of life.
@@ -267,8 +273,7 @@ def apply_red_agent_node_pol(
# Do nothing, service not on this node
pass
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - misconfiguration")
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
# Only apply the PoL if the checks have passed (based on the initiator type)
if passed_checks:
@@ -289,8 +294,7 @@ def apply_red_agent_node_pol(
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - did not pass checks")
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass

View File

@@ -1,8 +1,9 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Optional, Union
from typing import Any, Dict, Final, Optional, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
@@ -31,7 +32,7 @@ class PrimaiteSession:
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
The PrimaiteSession constructor.
@@ -71,7 +72,13 @@ class PrimaiteSession:
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
def setup(self):
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self) -> None:
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
@@ -154,8 +161,8 @@ class PrimaiteSession:
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -166,8 +173,8 @@ class PrimaiteSession:
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -176,6 +183,6 @@ class PrimaiteSession:
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(**kwargs)
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._agent_session.close()

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities to prepare the user's data folders."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,10 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import getLogger
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: Logger = getLogger(__name__)
def run():
def run() -> None:
"""Perform the full clean-up."""
pass

View File

@@ -1,17 +1,18 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil
from logging import Logger
from pathlib import Path
import pkg_resources
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run(overwrite_existing: bool = True):
def run(overwrite_existing: bool = True) -> None:
"""
Resets the demo jupyter notebooks in the users app notebooks directory.

View File

@@ -1,16 +1,21 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: Logger = getLogger(__name__)
def run(overwrite_existing=True):
def run(overwrite_existing: bool = True) -> None:
"""
Resets the example config files in the users app config directory.

View File

@@ -1,10 +1,12 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from logging import Logger
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run():
def run() -> None:
"""
Handles creation of application directories and user directories.

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Record data of the system's state and agent's observations and actions."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,15 +1,19 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Tuple
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from primaite.common.enums import AgentIdentifier
if TYPE_CHECKING:
import numpy as np
from gym import spaces
class Transaction(object):
"""Transaction class."""
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
"""
Transaction constructor.
@@ -17,7 +21,7 @@ class Transaction(object):
:param episode_number: The episode number
:param step_number: The step number
"""
self.timestamp = datetime.now()
self.timestamp: datetime = datetime.now()
"The datetime of the transaction"
self.agent_identifier: AgentIdentifier = agent_identifier
"The agent identifier"
@@ -25,17 +29,17 @@ class Transaction(object):
"The episode number"
self.step_number: int = step_number
"The step number"
self.obs_space = None
self.obs_space: "spaces.Space" = None
"The observation space (pre)"
self.obs_space_pre = None
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space before any actions are taken"
self.obs_space_post = None
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space after any actions are taken"
self.reward: float = None
self.reward: Optional[float] = None
"The reward value"
self.action_space = None
self.action_space: Optional[int] = None
"The action space invoked by the agent"
self.obs_space_description = None
self.obs_space_description: Optional[List[str]] = None
"The env observation space description"
def as_csv_data(self) -> Tuple[List, List]:
@@ -68,7 +72,7 @@ class Transaction(object):
return header, row
def _turn_action_space_to_array(action_space) -> List[str]:
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
"""
Turns action space into a string array so it can be saved to csv.
@@ -81,7 +85,7 @@ def _turn_action_space_to_array(action_space) -> List[str]:
return [str(action_space)]
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
"""
Turns observation space into a string array so it can be saved to csv.

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities for PrimAITE."""

View File

@@ -1,12 +1,13 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
from logging import Logger
from pathlib import Path
import pkg_resources
from primaite import getLogger
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def get_file_path(path: str) -> Path:

View File

@@ -1,6 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import json
from pathlib import Path
from typing import Union
from typing import Any, Dict, Union
import yaml
@@ -9,7 +10,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]:
"""
Loads a session metadata from the given directory path.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Any, Dict, Tuple, Union

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union
@@ -6,6 +7,9 @@ from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from io import TextIOWrapper
from pathlib import Path
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
@@ -28,7 +32,7 @@ class SessionOutputWriter:
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True,
):
) -> None:
"""
Initialise the Session Output Writer.
@@ -41,15 +45,16 @@ class SessionOutputWriter:
determines the name of the folder which contains the final output csv. Defaults to True
:type learning_session: bool, optional
"""
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
self._env: "Primaite" = env
self.transaction_writer: bool = transaction_writer
self.learning_session: bool = learning_session
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
self._csv_file_path: "Path"
if self.learning_session:
self._csv_file_path = self._env.session_path / "learning" / fn
else:
@@ -57,26 +62,26 @@ class SessionOutputWriter:
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
self._csv_file = None
self._csv_writer = None
self._csv_file: "TextIOWrapper" = None
self._csv_writer: "csv._writer" = None
self._first_write: bool = True
def _init_csv_writer(self):
def _init_csv_writer(self) -> None:
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
def __del__(self) -> None:
self.close()
def close(self):
def close(self) -> None:
"""Close the cvs file."""
if self._csv_file:
self._csv_file.close()
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(self, data: Union[Tuple, Transaction]):
def write(self, data: Union[Tuple, Transaction]) -> None:
"""
Write a row of session data.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Main Config File
# Generic config values

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Main Config File
# Generic config values

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import datetime
import json
import shutil

View File

@@ -0,0 +1 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import tempfile
from datetime import datetime
from pathlib import Path

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to tes the ACL functions."""
from primaite.acl.access_control_list import AccessControlList

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to test Active Node functions."""
import pytest

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Test env creation and behaviour with different observation spaces."""
import numpy as np

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
import pytest

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite.config.lay_down_config import data_manipulation_config_path

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Used to test Active Node functions."""
import pytest

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite import getLogger

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite import getLogger

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest as pytest
from primaite.config.lay_down_config import dos_very_basic_config_path

Some files were not shown because too many files have changed in this diff Show More