Merge branch 'dev' into feature/1632_Add_benchmarking_scripts

This commit is contained in:
Chris McCarthy
2023-07-18 13:24:15 +01:00
68 changed files with 1473 additions and 553 deletions

View File

@@ -66,11 +66,11 @@ The environment config file consists of the following attributes:
.. code-block:: yaml
observation_space:
flatten: true
components:
- name: NODE_LINK_TABLE
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
options:
combine_service_traffic : False
quantisation_levels: 99
@@ -80,6 +80,7 @@ The environment config file consists of the following attributes:
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
@@ -128,6 +129,14 @@ The environment config file consists of the following attributes:
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
* **implicit_acl_rule** [str]
Determines which implicit rule the ACL list has. The two options are DENY and ALLOW.
* **max_number_acl_rules** [int]
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session. The implicit rule counts towards this limit.
**Reward-Based Config Values**
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
@@ -477,3 +486,4 @@ The lay down config file consists of the following attributes:
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
* **position** [int]: Defines where to place the ACL rule in the list. A lower index (higher up in the list) means the rule is checked first. Indexing starts at 0, as in Python.
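To illustrate how ``position`` affects evaluation, here is a minimal sketch under stated assumptions rather than PrimAITE's own implementation. Rules sit in a fixed-length list, empty slots are skipped, and the first matching rule decides the outcome. The tuple layout and the ``first_match`` helper are hypothetical names used only for this example.

```python
from typing import List, Optional, Tuple

# Hypothetical rule layout for illustration: (permission, protocol, port)
Rule = Tuple[str, str, str]


def first_match(rules: List[Optional[Rule]], implicit: str, protocol: str, port: str) -> str:
    """Return the permission of the first matching rule, else the implicit permission."""
    for rule in rules:  # index 0 is checked first
        if rule is None:  # empty slot, nothing to check
            continue
        permission, rule_protocol, rule_port = rule
        if rule_protocol in (protocol, "ANY") and rule_port in (port, "ANY"):
            return permission
    return implicit


rules: List[Optional[Rule]] = [None] * 4
rules[2] = ("DENY", "TCP", "ANY")  # broader rule placed lower in the list
rules[0] = ("ALLOW", "TCP", "80")  # position 0 is evaluated before position 2

print(first_match(rules, "DENY", "TCP", "80"))   # ALLOW (matched at position 0)
print(first_match(rules, "DENY", "TCP", "22"))   # DENY (matched at position 2)
print(first_match(rules, "ALLOW", "UDP", "53"))  # ALLOW (no match, implicit rule applies)
```

Swapping the two positions would make the broader DENY rule win for port 80 as well, which is why the ordering matters when writing a laydown.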

View File

@@ -53,3 +53,5 @@ v1.2 to v2.0 Migration guide
* hard coded agent view
Each of these items has a default value designed so that PrimAITE behaves as it did in 1.2.0, so you do not have to specify them.
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will be placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.
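For larger custom laydowns, the migration step could be scripted. The sketch below is only an illustration: it assumes the laydown has already been loaded into a list of dictionaries, that each entry identifies its kind through an ``item_type`` key, and that no rule has a position yet. The key name, the sample entries and the ``add_positions`` helper are all assumptions, not part of PrimAITE.

```python
from typing import Any, Dict, List


def add_positions(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a copy of the laydown in which every ACL_RULE has a position.

    Positions are assigned in file order, starting from 0.
    """
    migrated = []
    next_position = 0
    for item in items:
        item = dict(item)  # shallow copy so the input is left untouched
        if item.get("item_type") == "ACL_RULE":
            item.setdefault("position", next_position)
            next_position += 1
        migrated.append(item)
    return migrated


# Hypothetical, heavily shortened laydown entries
laydown = [
    {"item_type": "SERVICES"},
    {"item_type": "ACL_RULE", "permission": "ALLOW"},
    {"item_type": "ACL_RULE", "permission": "DENY"},
]

for entry in add_positions(laydown):
    print(entry)
```

Assigning positions in file order preserves whatever ordering the rules already had in the file. Review the result by hand if some rules need to take priority over others.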

View File

@@ -47,6 +47,105 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
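The directory naming can be reproduced with ``strftime``. This is a sketch of the naming scheme described above, not the library's own code. ``example_session_path`` and its ``root`` parameter are hypothetical, and ``~`` is left unexpanded to keep the example independent of the host machine.

```python
from datetime import datetime
from pathlib import Path


def example_session_path(timestamp: datetime, root: Path) -> Path:
    """Build <root>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss> for a session timestamp."""
    date_dir = timestamp.strftime("%Y-%m-%d")
    session_dir = timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    return root / date_dir / session_dir


path = example_session_path(datetime(2023, 1, 31, 17, 30, 0), Path("~/primaite/sessions"))
print(path.as_posix())  # ~/primaite/sessions/2023-01-31/2023-01-31_17-30-00
```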
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents (training checkpoints and a final trained agent)
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date & time the session started in ISO 8601 format.
* **end_datetime** - The date & time the session ended in ISO 8601 format.
* **learning**
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **evaluation**
* **total_episodes** - The total number of evaluation episodes completed.
* **total_time_steps** - The total number of evaluation time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
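As a rough illustration of working with this file, the sketch below parses a metadata document shaped like the list above and derives the session duration. The sample values are placeholders invented for the example, and the config sections are left empty.

```python
import json
from datetime import datetime

# Placeholder content following the documented key structure
sample = """
{
  "uuid": "00000000-0000-0000-0000-000000000000",
  "start_datetime": "2023-07-18T11:06:04",
  "end_datetime": "2023-07-18T11:10:00",
  "learning": {"total_episodes": 10, "total_time_steps": 2560},
  "evaluation": {"total_episodes": 5, "total_time_steps": 1280},
  "env": {"training_config": {}, "lay_down_config": {}}
}
"""

metadata = json.loads(sample)

# ISO-formatted timestamps can be parsed directly
start = datetime.fromisoformat(metadata["start_datetime"])
end = datetime.fromisoformat(metadata["end_datetime"])
duration = end - start

print(metadata["learning"]["total_episodes"])  # 10
print(duration.total_seconds())  # 236.0
```

For a real session, replace ``json.loads(sample)`` with ``json.load`` on the opened ``session_metadata.json`` file.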
**Results**
PrimAITE automatically creates two sets of results from each learning and evaluation session:
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of how the reward value changes over a training session
* All transactions - a csv file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
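The transactions file can be inspected with the standard ``csv`` module. The sketch below is illustrative only: the header names other than the ``AS_X`` pattern are assumptions, and the rows are made-up sample data, so check the header of a real file before relying on it.

```python
import csv
import io

# Made-up sample rows; only the AS_X naming pattern comes from the description above
sample_csv = io.StringIO(
    "Timestamp,Episode,Step,Reward,AS_1,AS_2\n"
    "2023-07-18 11:06:05,1,1,0.5,0,3\n"
    "2023-07-18 11:06:06,1,2,-0.25,1,2\n"
)

rows = list(csv.DictReader(sample_csv))

# Collect the action-space elements of each step by their AS_ prefix
actions = [[int(value) for key, value in row.items() if key.startswith("AS_")] for row in rows]
total_reward = sum(float(row["Reward"]) for row in rows)

print(actions)       # [[0, 3], [1, 2]]
print(total_reward)  # 0.25
```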
**Diagrams**
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using Plotly and saves it to the learning or evaluation subdirectory in the session directory.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
**Example Session Directory Structure**
.. code-block:: text
~/
└── primaite/
└── sessions/
└── 2023-07-18/
└── 2023-07-18_11-06-04/
├── evaluation/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
├── learning/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
│ ├── checkpoints/
│ │ └── sb3ppo_10.zip
│ ├── SB3_PPO.zip
│ └── tensorboard_logs/
│ ├── PPO_1/
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
│ ├── PPO_2/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
│ ├── PPO_3/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
│ ├── PPO_4/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
│ ├── PPO_5/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
│ ├── PPO_6/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
│ ├── PPO_7/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
│ ├── PPO_8/
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
│ ├── PPO_9/
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
│ └── PPO_10/
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
├── network_2023-07-18_11-06-04.png
└── session_metadata.json
Loading a session
-----------------
@@ -78,52 +177,3 @@ A previous session can be loaded by providing the **directory** of the previous
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, it writes its output to the provided session directory.
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date & time the session started in iso format.
* **end_datetime** - The date & time the session ended in iso format.
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
**Results**
PrimAITE automatically creates two sets of results from each session:
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
* All transactions - a csv file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
**Diagrams**
For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.

View File

@@ -6,7 +6,7 @@ from bisect import bisect
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict, Final
from typing import Any, Dict, Final
import pkg_resources
import yaml
@@ -16,7 +16,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
"""An instance of `PlatformDirs` set with appname='primaite'."""
def _get_primaite_config():
def _get_primaite_config() -> Dict:
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
if not config_path.exists():
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
Credit to: https://stackoverflow.com/a/68154386
"""
def __init__(self, formats: Dict[int, str], **kwargs):
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
super().__init__()
if "fmt" in kwargs:

View File

@@ -1,16 +1,38 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict
import logging
from typing import Dict, Final, List, Union
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class AccessControlList:
"""Access Control List class."""
def __init__(self):
"""Initialise an empty AccessControlList."""
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
"""Init."""
# Implicit ALLOW or DENY firewall spec
self.acl_implicit_permission = implicit_permission
# Implicit rule in ACL list
if self.acl_implicit_permission == RulePermissionType.DENY:
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
else:
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
# Maximum number of ACL Rules in ACL
self.max_acl_rules: int = max_acl_rules
# A list of ACL Rules
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
@property
def acl(self) -> List[Union[ACLRule, None]]:
"""Public access method for private _acl."""
return self._acl + [self.acl_implicit_rule]
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
@@ -47,21 +69,30 @@ class AccessControlList:
Returns:
Indicates block if all conditions are satisfied.
"""
for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule_value.get_permission() == "DENY":
return True
elif rule_value.get_permission() == "ALLOW":
return False
for rule in self.acl:
if isinstance(rule, ACLRule):
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule.get_permission() == RulePermissionType.DENY:
return True
elif rule.get_permission() == RulePermissionType.ALLOW:
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def add_rule(
self,
_permission: RulePermissionType,
_source_ip: str,
_dest_ip: str,
_protocol: str,
_port: str,
_position: str,
) -> None:
"""
Adds a new rule.
@@ -71,12 +102,36 @@ class AccessControlList:
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
_position: position to insert ACL rule into ACL list (0-based index; must be less than max_acl_rules - 1)
"""
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
try:
position_index = int(_position)
except (TypeError, ValueError):
_LOGGER.info(f"Position {_position} could not be converted to integer.")
return
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
# Checks position is in correct range
if self.max_acl_rules - 1 > position_index > -1:
try:
_LOGGER.info(f"Position {position_index} is valid.")
# Check to see Agent will not overwrite current ACL in ACL list
if self._acl[position_index] is None:
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
# Adds rule
self._acl[position_index] = new_rule
else:
# Cannot overwrite it
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
return
except Exception:
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
else:
_LOGGER.info(f"Position {position_index} is invalid or would overwrite the implicit firewall rule")
def remove_rule(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Removes a rule.
@@ -87,19 +142,21 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
# There will not always be something 'popable' since the agent will be trying random things
try:
self.acl.pop(hash_value)
except Exception:
return
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
delete_rule_hash = hash(rule_to_delete)
def remove_all_rules(self):
for index in range(0, len(self._acl)):
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
self._acl[index] = None
def remove_all_rules(self) -> None:
"""Removes all rules."""
self.acl.clear()
for i in range(len(self._acl)):
self._acl[i] = None
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def get_dictionary_hash(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> int:
"""
Produces a hash value for a rule.
@@ -117,7 +174,9 @@ class AccessControlList:
hash_value = hash(rule)
return hash_value
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
def get_relevant_rules(
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
) -> Dict[int, ACLRule]:
"""Get all ACL rules that relate to the given arguments.
:param _source_ip_address: the source IP address to check
@@ -125,18 +184,15 @@ class AccessControlList:
:param _protocol: the protocol to check
:param _port: the port to check
:return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[str, ACLRule]
:rtype: Dict[int, ACLRule]
"""
relevant_rules = {}
for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (
rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY"
) and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY"
for rule in self.acl:
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
):
# There's a matching rule.
relevant_rules[rule_key] = rule_value
relevant_rules[self._acl.index(rule)] = rule
return relevant_rules
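The slot-based behaviour introduced in this file can be summarised with a small stand-alone sketch. It is a simplified illustration, not the class above: rules are plain strings, and ``SlotList`` with its method names is hypothetical. It mirrors three properties of the new code: ``max_rules - 1`` explicit slots, an implicit entry that always comes last, and no overwriting of occupied slots.

```python
from typing import List, Optional


class SlotList:
    """Fixed number of explicit slots plus one implicit entry that is always last."""

    def __init__(self, implicit: str, max_rules: int) -> None:
        self.implicit = implicit
        self._slots: List[Optional[str]] = [None] * (max_rules - 1)

    @property
    def rules(self) -> List[Optional[str]]:
        """Explicit slots followed by the implicit entry."""
        return self._slots + [self.implicit]

    def add(self, rule: str, position: int) -> bool:
        """Insert a rule; refuse out-of-range positions and occupied slots."""
        if not 0 <= position < len(self._slots):
            return False  # outside the explicit slots (would hit the implicit entry)
        if self._slots[position] is not None:
            return False  # never overwrite an existing rule
        self._slots[position] = rule
        return True


acl = SlotList(implicit="DENY", max_rules=4)
print(acl.add("ALLOW tcp/80", 0))  # True
print(acl.add("ALLOW tcp/22", 0))  # False, slot 0 is taken
print(acl.add("ALLOW tcp/22", 3))  # False, index 3 belongs to the implicit entry
print(acl.rules)  # ['ALLOW tcp/80', None, None, 'DENY']
```

Returning a boolean instead of only logging is a design choice made for the sketch so that callers and tests can tell whether the insert happened.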

View File

@@ -1,11 +1,14 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType
class ACLRule:
"""Access Control List Rule class."""
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def __init__(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Initialise an ACL Rule.
@@ -15,13 +18,13 @@ class ACLRule:
:param _protocol: The rule protocol
:param _port: The rule port
"""
self.permission = _permission
self.source_ip = _source_ip
self.dest_ip = _dest_ip
self.protocol = _protocol
self.port = _port
self.permission: RulePermissionType = _permission
self.source_ip: str = _source_ip
self.dest_ip: str = _dest_ip
self.protocol: str = _protocol
self.port: str = _port
def __hash__(self):
def __hash__(self) -> int:
"""
Override the hash function.
@@ -38,7 +41,7 @@ class ACLRule:
)
)
def get_permission(self):
def get_permission(self) -> RulePermissionType:
"""
Gets the permission attribute.
@@ -47,7 +50,7 @@ class ACLRule:
"""
return self.permission
def get_source_ip(self):
def get_source_ip(self) -> str:
"""
Gets the source IP address attribute.
@@ -56,7 +59,7 @@ class ACLRule:
"""
return self.source_ip
def get_dest_ip(self):
def get_dest_ip(self) -> str:
"""
Gets the destination IP address attribute.
@@ -65,7 +68,7 @@ class ACLRule:
"""
return self.dest_ip
def get_protocol(self):
def get_protocol(self) -> str:
"""
Gets the protocol attribute.
@@ -74,7 +77,7 @@ class ACLRule:
"""
return self.protocol
def get_port(self):
def get_port(self) -> str:
"""
Gets the port attribute.

View File

@@ -4,8 +4,9 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
from uuid import uuid4
import primaite
@@ -16,7 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
@@ -51,7 +52,7 @@ class AgentSessionABC(ABC):
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise an agent session from config files, or load a previous session.
@@ -131,11 +132,11 @@ class AgentSessionABC(ABC):
return path
@property
def uuid(self):
def uuid(self) -> str:
"""The Agent Session UUID."""
return self._uuid
def _write_session_metadata_file(self):
def _write_session_metadata_file(self) -> None:
"""
Write the ``session_metadata.json`` file.
@@ -171,7 +172,7 @@ class AgentSessionABC(ABC):
json.dump(metadata_dict, file)
_LOGGER.debug("Finished writing session metadata file")
def _update_session_metadata_file(self):
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
@@ -200,7 +201,7 @@ class AgentSessionABC(ABC):
_LOGGER.debug("Finished updating session metadata file")
@abstractmethod
def _setup(self):
def _setup(self) -> None:
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
@@ -210,14 +211,14 @@ class AgentSessionABC(ABC):
self._can_evaluate = False
@abstractmethod
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
@abstractmethod
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -234,8 +235,8 @@ class AgentSessionABC(ABC):
@abstractmethod
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -248,10 +249,10 @@ class AgentSessionABC(ABC):
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def load(self, path: Union[str, Path]):
def load(self, path: Union[str, Path]) -> None:
"""Load an agent from file."""
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
@@ -275,21 +276,21 @@ class AgentSessionABC(ABC):
return self.learning_path / file_name
@abstractmethod
def save(self):
def save(self) -> None:
"""Save the agent."""
pass
@abstractmethod
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
pass
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._env.episode_av_reward_writer.close() # noqa
self._env.transaction_writer.close() # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True):
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
# self.close()
title = f"PrimAITE Session {self.timestamp_str} "
subtitle = str(self._training_config)

View File

@@ -2,7 +2,9 @@
import time
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
import numpy as np
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
@@ -24,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise a hardcoded agent session.
@@ -37,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self):
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -48,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -66,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -103,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._env.close()
@classmethod
def load(cls, path=None):
def load(cls, path: Optional[Union[str, Path]] = None) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,5 +1,5 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Any, Dict, List, Union
from typing import Dict, List, Union
import numpy as np
@@ -33,7 +33,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[Any, Any]:
) -> Dict[str, IER]:
"""Get blocked green IERs.
:param green_iers: Green IERs to check for being
@@ -61,7 +61,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
def get_matching_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get list of ACL rules which are relevant to an IER.
:param ier: Information Exchange Request to query against the ACL list
@@ -84,7 +86,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
) -> Dict[int, ACLRule]:
"""
Get blocking ACL rules for an IER.
@@ -112,7 +114,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
) -> Dict[int, ACLRule]:
"""Get all allowing ACL rules for an IER.
:param ier: Information Exchange Request to query against the ACL list
@@ -142,7 +144,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
:param source_node_id: Source node
@@ -174,6 +176,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if protocol != "ANY":
protocol = services_list[protocol - 1] # -1 as we don't have to account for ANY in the list of services
# TODO: This should throw an error because protocol is a string
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
@@ -187,7 +190,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""List ALLOW rules relating to specified nodes.
:param source_node_id: Source node id
@@ -234,7 +237,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""List DENY rules relating to specified nodes.
:param source_node_id: Source node id

View File

@@ -102,6 +102,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
property_action,
action_service_index,
]
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
action = transform_action_node_enum(action, action_dict)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step

View File

@@ -4,8 +4,9 @@ from __future__ import annotations
import json
import shutil
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Optional, Union
from typing import Any, Callable, Dict, Optional, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -19,10 +20,11 @@ from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def _env_creator(env_config):
# TODO: verify type of env_config
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
@@ -31,11 +33,12 @@ def _env_creator(env_config):
)
def _custom_log_creator(session_path: Path):
# TODO: verify type hint return type
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config):
def logger_creator(config: Dict) -> UnifiedLogger:
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
@@ -49,7 +52,7 @@ class RLlibAgent(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise the RLLib Agent training session.
@@ -74,6 +77,7 @@ class RLlibAgent(AgentSessionABC):
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config_class: Union[PPOConfig, A2CConfig]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
@@ -95,7 +99,7 @@ class RLlibAgent(AgentSessionABC):
f"{self._training_config.deep_learning_framework}"
)
def _update_session_metadata_file(self):
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
@@ -123,7 +127,7 @@ class RLlibAgent(AgentSessionABC):
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self):
def _setup(self) -> None:
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
@@ -149,7 +153,7 @@ class RLlibAgent(AgentSessionABC):
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
save_checkpoint = False
@@ -160,8 +164,8 @@ class RLlibAgent(AgentSessionABC):
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -181,8 +185,8 @@ class RLlibAgent(AgentSessionABC):
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -190,7 +194,7 @@ class RLlibAgent(AgentSessionABC):
"""
raise NotImplementedError
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError
@classmethod
@@ -198,7 +202,7 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self, overwrite_existing: bool = True):
def save(self, overwrite_existing: bool = True) -> None:
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
@@ -218,6 +222,6 @@ class RLlibAgent(AgentSessionABC):
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -2,8 +2,9 @@
from __future__ import annotations
import json
from logging import Logger
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
@@ -14,7 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
class SB3Agent(AgentSessionABC):
@@ -25,7 +26,7 @@ class SB3Agent(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise the SB3 Agent training session.
@@ -43,6 +44,7 @@ class SB3Agent(AgentSessionABC):
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_class: Union[PPO, A2C]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
@@ -66,7 +68,7 @@ class SB3Agent(AgentSessionABC):
self._setup()
def _setup(self):
def _setup(self) -> None:
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
@@ -113,7 +115,7 @@ class SB3Agent(AgentSessionABC):
super()._setup()
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
save_checkpoint = False
@@ -124,13 +126,13 @@ class SB3Agent(AgentSessionABC):
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -153,8 +155,8 @@ class SB3Agent(AgentSessionABC):
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -183,10 +185,10 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
def save(self):
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,4 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
@@ -10,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC):
Get a completely random action from the action space.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
return self._env.action_space.sample()
@@ -21,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC):
All action spaces setup so dummy action is always 0 regardless of action type used.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
return 0
@@ -32,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
@@ -47,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)

View File

@@ -35,11 +35,11 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
return new_action
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.

View File

@@ -19,7 +19,7 @@ app = typer.Typer()
@app.command()
def build_dirs():
def build_dirs() -> None:
"""Build the PrimAITE app directories."""
from primaite.setup import setup_app_dirs
@@ -27,7 +27,7 @@ def build_dirs():
@app.command()
def reset_notebooks(overwrite: bool = True):
def reset_notebooks(overwrite: bool = True) -> None:
"""
Force a reset of the demo notebooks in the user's notebooks directory.
@@ -39,7 +39,7 @@ def reset_notebooks(overwrite: bool = True):
@app.command()
def logs(last_n: Annotated[int, typer.Option("-n")]):
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
"""
Print the PrimAITE log file.
@@ -60,7 +60,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
@app.command()
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
"""
View or set the PrimAITE Log Level.
@@ -88,7 +88,7 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
@app.command()
def notebooks():
def notebooks() -> None:
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
@@ -96,7 +96,7 @@ def notebooks():
@app.command()
def version():
def version() -> None:
"""Get the installed PrimAITE version number."""
import primaite
@@ -104,7 +104,7 @@ def version():
@app.command()
def clean_up():
def clean_up() -> None:
"""Cleans up left over files from previous version installations."""
from primaite.setup import old_installation_clean_up
@@ -112,7 +112,7 @@ def clean_up():
@app.command()
def setup(overwrite_existing: bool = True):
def setup(overwrite_existing: bool = True) -> None:
"""
Perform the PrimAITE first-time setup.
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None):
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
"""
Run a PrimAITE session.
@@ -185,7 +185,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
"""
View or set the plotly template for Session plots.

View File

@@ -1,9 +1,8 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Type, Union
from typing import Union
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode]
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""

View File

@@ -148,6 +148,7 @@ class ActionType(Enum):
ANY = 2
# TODO: this is not used anymore, write a ticket to delete it.
class ObservationType(Enum):
"""Observation type enumeration."""
@@ -197,3 +198,11 @@ class SB3OutputVerboseLevel(IntEnum):
NONE = 0
INFO = 1
DEBUG = 2
class RulePermissionType(Enum):
"""Any firewall rule type."""
NONE = 0
DENY = 1
ALLOW = 2

View File

@@ -5,17 +5,17 @@
class Protocol(object):
"""Protocol class."""
def __init__(self, _name):
def __init__(self, _name: str) -> None:
"""
Initialise a protocol.
:param _name: The name of the protocol
:type _name: str
"""
self.name = _name
self.load = 0 # bps
self.name: str = _name
self.load: int = 0 # bps
def get_name(self):
def get_name(self) -> str:
"""
Gets the protocol name.
@@ -24,7 +24,7 @@ class Protocol(object):
"""
return self.name
def get_load(self):
def get_load(self) -> int:
"""
Gets the protocol load.
@@ -33,7 +33,7 @@ class Protocol(object):
"""
return self.load
def add_load(self, _load):
def add_load(self, _load: int) -> None:
"""
Adds load to the protocol.
@@ -42,6 +42,6 @@ class Protocol(object):
"""
self.load += _load
def clear_load(self):
def clear_load(self) -> None:
"""Clears the load on this protocol."""
self.load = 0
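
Reviewer sketch (not part of the commit): the typed `Protocol` methods above can be exercised as follows. The class body is a stand-in copied from the signatures in this diff so that the snippet runs on its own; the `tcp` variable and the load values are illustrative assumptions.

```python
class Protocol:
    """Stand-in mirroring the typed Protocol class shown in this diff."""

    def __init__(self, _name: str) -> None:
        self.name: str = _name
        self.load: int = 0  # bps

    def get_name(self) -> str:
        return self.name

    def get_load(self) -> int:
        return self.load

    def add_load(self, _load: int) -> None:
        # Load accumulates across calls until it is cleared
        self.load += _load

    def clear_load(self) -> None:
        self.load = 0


# Illustrative usage: accumulate load, read it, then reset it
tcp = Protocol("TCP")
tcp.add_load(1_000)
tcp.add_load(500)
load_before_clear = tcp.get_load()
tcp.clear_load()
load_after_clear = tcp.get_load()
```

With the annotations in place, a static checker can flag a call such as `tcp.add_load("500")`, which the untyped version would only reject at runtime.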

View File

@@ -7,7 +7,7 @@ from primaite.common.enums import SoftwareState
class Service(object):
"""Service class."""
def __init__(self, name: str, port: str, software_state: SoftwareState):
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
"""
Initialise a service.
@@ -15,12 +15,12 @@ class Service(object):
:param port: The service port.
:param software_state: The service SoftwareState.
"""
self.name = name
self.port = port
self.software_state = software_state
self.patching_count = 0
self.name: str = name
self.port: str = port
self.software_state: SoftwareState = software_state
self.patching_count: int = 0
def reduce_patching_count(self):
def reduce_patching_count(self) -> None:
"""Reduces the patching count for the service."""
self.patching_count -= 1
if self.patching_count <= 0:

View File

@@ -163,3 +163,4 @@
destination: ANY
protocol: ANY
port: ANY
position: 0

View File

@@ -243,6 +243,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '26'
permission: ALLOW
@@ -250,6 +251,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '27'
permission: ALLOW
@@ -257,6 +259,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 2
- item_type: ACL_RULE
id: '28'
permission: ALLOW
@@ -264,6 +267,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 3
- item_type: ACL_RULE
id: '29'
permission: ALLOW
@@ -271,6 +275,7 @@
destination: 192.168.10.13
protocol: TCP
port: 80
position: 4
- item_type: ACL_RULE
id: '30'
permission: DENY
@@ -278,6 +283,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 5
- item_type: ACL_RULE
id: '31'
permission: DENY
@@ -285,6 +291,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 6
- item_type: ACL_RULE
id: '32'
permission: DENY
@@ -292,6 +299,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 7
- item_type: ACL_RULE
id: '33'
permission: DENY
@@ -299,6 +307,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 8
- item_type: RED_POL
id: '34'
start_step: 20

View File

@@ -111,6 +111,7 @@
destination: 192.168.1.4
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '12'
permission: ALLOW
@@ -118,6 +119,7 @@
destination: 192.168.1.4
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '13'
permission: ALLOW
@@ -125,6 +127,7 @@
destination: 192.168.1.3
protocol: TCP
port: 80
position: 2
- item_type: RED_POL
id: '14'
start_step: 20

View File

@@ -345,6 +345,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 0
- item_type: ACL_RULE
id: '34'
permission: ALLOW
@@ -352,6 +353,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 1
- item_type: ACL_RULE
id: '35'
permission: ALLOW
@@ -359,6 +361,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 2
- item_type: ACL_RULE
id: '36'
permission: ALLOW
@@ -366,6 +369,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 3
- item_type: ACL_RULE
id: '37'
permission: ALLOW
@@ -373,6 +377,7 @@
destination: 192.168.10.11
protocol: ANY
port: ANY
position: 4
- item_type: ACL_RULE
id: '38'
permission: ALLOW
@@ -380,6 +385,7 @@
destination: 192.168.10.12
protocol: ANY
port: ANY
position: 5
- item_type: ACL_RULE
id: '39'
permission: ALLOW
@@ -387,6 +393,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 6
- item_type: ACL_RULE
id: '40'
permission: ALLOW
@@ -394,6 +401,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 7
- item_type: ACL_RULE
id: '41'
permission: ALLOW
@@ -401,6 +409,7 @@
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 8
- item_type: ACL_RULE
id: '42'
permission: ALLOW
@@ -408,6 +417,7 @@
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 9
- item_type: ACL_RULE
id: '43'
permission: ALLOW
@@ -415,6 +425,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 10
- item_type: ACL_RULE
id: '44'
permission: ALLOW
@@ -422,6 +433,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 11
- item_type: ACL_RULE
id: '45'
permission: ALLOW
@@ -429,6 +441,7 @@
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 12
- item_type: ACL_RULE
id: '46'
permission: ALLOW
@@ -436,6 +449,7 @@
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 13
- item_type: ACL_RULE
id: '47'
permission: ALLOW
@@ -443,6 +457,7 @@
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 14
- item_type: ACL_RULE
id: '48'
permission: ALLOW
@@ -450,6 +465,7 @@
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 15
- item_type: ACL_RULE
id: '49'
permission: DENY
@@ -457,6 +473,7 @@
destination: ANY
protocol: ANY
port: ANY
position: 16
- item_type: RED_POL
id: '50'
start_step: 50

View File

@@ -35,7 +35,7 @@ random_red_agent: False
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
@@ -51,15 +51,15 @@ hard_coded_agent_view: FULL
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
action_type: ANY
# observation space
observation_space:
# flatten: true
flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
# Number of episodes for training to run per session
num_train_episodes: 10
@@ -90,6 +90,11 @@ session_type: TRAIN_EVAL
# The high value for the observation space
observation_space_high_value: 1000000000
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 30
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)

View File

@@ -1,4 +1,5 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Union
@@ -6,7 +7,7 @@ import yaml
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
@@ -14,11 +15,12 @@ from primaite.common.enums import (
AgentIdentifier,
DeepLearningFramework,
HardCodedAgentView,
RulePermissionType,
SB3OutputVerboseLevel,
SessionType,
)
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
@@ -85,7 +87,7 @@ class TrainingConfig:
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: str = False
load_agent: bool = False
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
@@ -98,6 +100,12 @@ class TrainingConfig:
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
"Stable Baselines3 learn/eval output verbosity level"
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
"The implicit firewall rule (ALLOW or DENY) applied at the end of the ACL list."
max_number_acl_rules: int = 30
"The maximum number of ACL rules allowed in the ACL list for the environment."
# Reward values
# Generic
all_ok: float = 0
@@ -191,7 +199,7 @@ class TrainingConfig:
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
@@ -206,14 +214,17 @@ class TrainingConfig:
"session_type": SessionType,
"sb3_output_verbose_level": SB3OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView,
"implicit_acl_rule": RulePermissionType,
}
# convert the string representation of enums into the actual enum values themselves?
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True):
def to_dict(self, json_serializable: bool = True) -> Dict:
"""
Serialise the ``TrainingConfig`` as dict.
@@ -230,6 +241,7 @@ class TrainingConfig:
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
data["session_type"] = self.session_type.name
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
data["implicit_acl_rule"] = self.implicit_acl_rule.name
return data
@@ -332,7 +344,7 @@ def convert_legacy_training_config_dict(
return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> str:
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
"""
Maps legacy training config keys to the new format keys.

View File

@@ -2,12 +2,14 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from logging import Logger
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
from gym import spaces
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
@@ -18,14 +20,14 @@ if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
_LOGGER: Logger = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise observation component.
@@ -40,7 +42,7 @@ class AbstractObservationComponent(ABC):
return NotImplemented
@abstractmethod
def update(self):
def update(self) -> None:
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@@ -75,7 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
_MAX_VAL: int = 1_000_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeLinkTable observation space component.
@@ -102,7 +104,7 @@ class NodeLinkTable(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -149,7 +151,7 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_index += 1
item_index += 1
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
@@ -212,7 +214,7 @@ class NodeStatuses(AbstractObservationComponent):
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeStatuses observation component.
@@ -238,7 +240,7 @@ class NodeStatuses(AbstractObservationComponent):
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -269,11 +271,12 @@ class NodeStatuses(AbstractObservationComponent):
)
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
structure = []
for _, node in self.env.nodes.items():
node_id = node.node_id
structure.append(f"node_{node_id}_hardware_state_NONE")
@@ -318,7 +321,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
) -> None:
"""
Initialise a LinkTrafficLevels observation component.
@@ -360,7 +363,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -386,7 +389,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
@@ -402,6 +405,182 @@ class LinkTrafficLevels(AbstractObservationComponent):
return structure
class AccessControlList(AbstractObservationComponent):
"""Flat list of all the Access Control Rules in the Access Control List.
The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states, represented by
integers.
Each ACL Rule has 6 elements. It will have the following structure:
.. code-block::
[
acl_rule1 permission,
acl_rule1 source_ip,
acl_rule1 dest_ip,
acl_rule1 protocol,
acl_rule1 port,
acl_rule1 position,
acl_rule2 permission,
acl_rule2 source_ip,
acl_rule2 dest_ip,
acl_rule2 protocol,
acl_rule2 port,
acl_rule2 position,
...
]
Terms (for ACL Observation Space):
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite") -> None:
"""
Initialise an AccessControlList observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
# Positions are zero-based list indices, so the position dimension has max_number_acl_rules values.
acl_shape = [
len(RulePermissionType),
len(env.nodes) + 2,
len(env.nodes) + 2,
len(env.services_list) + 2,
len(env.ports_list) + 2,
env.max_number_acl_rules,
]
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission
source_ip = acl_rule.source_ip
dest_ip = acl_rule.dest_ip
protocol = acl_rule.protocol
port = acl_rule.port
position = index
# Map each ACL attribute from what it was to an integer to fit the observation space
source_ip_int = None
dest_ip_int = None
if permission == RulePermissionType.DENY:
permission_int = 1
else:
permission_int = 2
if source_ip == "ANY":
source_ip_int = 1
else:
# Map Node ID (+ 1) to source IP address
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == source_ip:
source_ip_int = int(node.node_id) + 1
break
if dest_ip == "ANY":
dest_ip_int = 1
else:
# Map Node ID (+ 1) to dest IP address
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip:
dest_ip_int = int(node.node_id) + 1
if protocol == "ANY":
protocol_int = 1
else:
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
try:
protocol_int = self.env.services_list.index(protocol) + 2
except ValueError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = None
if port == "ANY":
port_int = 1
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port) + 2
else:
_LOGGER.info(f"Port {port} could not be found.")
port_int = None
# Add to current obs
obs.extend(
[
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
)
else:
# The Nothing or NA representation of 'NONE' ACL rules
obs.extend([0, 0, 0, 0, 0, 0])
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
# Use the positional index; list.index() returns the first match, which is wrong for repeated (e.g. None) entries
for acl_rule_id, _ in enumerate(self.env.acl.acl):
for permission in RulePermissionType:
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
for service in self.env.services_list:
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
for port in self.env.ports_list:
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
return structure
class ObservationsHandler:
"""
Component-based observation space handler.
@@ -414,9 +593,10 @@ class ObservationsHandler:
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
"ACCESS_CONTROL_LIST": AccessControlList,
}
def __init__(self):
def __init__(self) -> None:
"""Initialise the observation handler."""
self.registered_obs_components: List[AbstractObservationComponent] = []
@@ -429,9 +609,7 @@ class ObservationsHandler:
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
self.flatten: bool = False
def update_obs(self):
def update_obs(self) -> None:
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
@@ -444,7 +622,7 @@ class ObservationsHandler:
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
def register(self, obs_component: AbstractObservationComponent) -> None:
"""
Add a component for this handler to track.
@@ -454,7 +632,7 @@ class ObservationsHandler:
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
def deregister(self, obs_component: AbstractObservationComponent) -> None:
"""
Remove a component from this handler.
@@ -465,7 +643,7 @@ class ObservationsHandler:
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
def update_space(self) -> None:
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
@@ -482,23 +660,23 @@ class ObservationsHandler:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self):
def space(self) -> spaces.Space:
"""Observation space, return the flattened version if flatten is True."""
if self.flatten:
if len(self.registered_obs_components) > 1:
return self._flat_space
else:
return self._space
@property
def current_observation(self):
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""Current observation, return the flattened version if flatten is True."""
if self.flatten:
if len(self.registered_obs_components) > 1:
return self._flat_observation
else:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
"""
Parse a config dictionary, return a new observation handler populated with new observation component objects.
@@ -527,9 +705,6 @@ class ObservationsHandler:
# Instantiate the handler
handler = cls()
if obs_space_config.get("flatten"):
handler.flatten = True
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
@@ -544,7 +719,7 @@ class ObservationsHandler:
handler.update_obs()
return handler
def describe_structure(self):
def describe_structure(self) -> List[str]:
"""
Create a list of names for the features of the obs space.

View File

@@ -3,9 +3,10 @@
import copy
import logging
import uuid as uuid
from logging import Logger
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Dict, Final, Tuple, Union
from typing import Any, Dict, Final, List, Tuple, Union
import networkx as nx
import numpy as np
@@ -20,6 +21,7 @@ from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
FileSystemState,
HardwareState,
NodePOLInitiator,
@@ -48,7 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
class Primaite(Env):
@@ -66,7 +68,7 @@ class Primaite(Env):
lay_down_config_path: Union[str, Path],
session_path: Path,
timestamp_str: str,
):
) -> None:
"""
The Primaite constructor.
@@ -77,13 +79,14 @@ class Primaite(Env):
"""
self.session_path: Final[Path] = session_path
self.timestamp_str: Final[str] = timestamp_str
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self._training_config_path: Union[str, Path] = training_config_path
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
self.training_config: TrainingConfig = training_config.load(training_config_path)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps: int
if self.training_config.session_type == SessionType.TRAIN:
self.episode_steps = self.training_config.num_train_steps
elif self.training_config.session_type == SessionType.EVAL:
@@ -94,7 +97,7 @@ class Primaite(Env):
super(Primaite, self).__init__()
# The agent in use
self.agent_identifier = self.training_config.agent_identifier
self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier
# Create a dictionary to hold all the nodes
self.nodes: Dict[str, NodeUnion] = {}
@@ -113,37 +116,42 @@ class Primaite(Env):
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
self.node_pol = {}
self.node_pol: Dict[str, NodeStateInstructionGreen] = {}
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
self.red_iers = {}
self.red_iers: Dict[str, IER] = {}
# Create a dictionary to hold all the red agent node PoLs (this will come from an external source)
self.red_node_pol = {}
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
# Create the Access Control List
self.acl = AccessControlList()
self.acl: AccessControlList = AccessControlList(
self.training_config.implicit_acl_rule,
self.training_config.max_number_acl_rules,
)
# Sets limit for number of ACL rules in environment
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
# Create a list of services (enums)
self.services_list = []
self.services_list: List[str] = []
# Create a list of ports
self.ports_list = []
self.ports_list: List[str] = []
# Create graph (network)
self.network = nx.MultiGraph()
self.network: nx.Graph = nx.MultiGraph()
# Create a graph (network) reference
self.network_reference = nx.MultiGraph()
self.network_reference: nx.Graph = nx.MultiGraph()
# Create step count
self.step_count = 0
self.step_count: int = 0
self.total_step_count: int = 0
"""The total number of time steps completed."""
# Create step info dictionary
self.step_info = {}
self.step_info: Dict[Any, Any] = {}
# Total reward
self.total_reward: float = 0
@@ -152,22 +160,23 @@ class Primaite(Env):
self.average_reward: float = 0
# Episode count
self.episode_count = 0
self.episode_count: int = 0
# Number of nodes - gets a value by examining the nodes dictionary after it's been populated
self.num_nodes = 0
self.num_nodes: int = 0
# Number of links - gets a value by examining the links dictionary after it's been populated
self.num_links = 0
self.num_links: int = 0
# Number of services - gets a value when config is loaded
self.num_services = 0
self.num_services: int = 0
# Number of ports - gets a value when config is loaded
self.num_ports = 0
self.num_ports: int = 0
# The action type
self.action_type = 0
# TODO: confirm type
self.action_type: int = 0
# TODO fix up with TrainingConfig
# stores the observation config from the yaml, default is NODE_LINK_TABLE
@@ -179,7 +188,7 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
self._obs_space_description = None
self._obs_space_description: List[str] = None
"The env observation space description for transactions writing"
# Open the config file and build the environment laydown
@@ -211,9 +220,13 @@ class Primaite(Env):
_LOGGER.error("Could not save network diagram", exc_info=True)
# Initiate observation space
self.observation_space: spaces.Space
self.env_obs: np.ndarray
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
self.action_dict: Dict[int, List[int]]
self.action_space: spaces.Space
if self.training_config.action_type == ActionType.NODE:
_LOGGER.debug("Action space type NODE selected")
# Terms (for node action space):
@@ -241,8 +254,12 @@ class Primaite(Env):
else:
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
self, transaction_writer=False, learning_session=True
)
self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
self, transaction_writer=True, learning_session=True
)
@property
def actual_episode_count(self) -> int:
@@ -251,7 +268,7 @@ class Primaite(Env):
return self.episode_count - 1
return self.episode_count
def set_as_eval(self):
def set_as_eval(self) -> None:
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
@@ -260,12 +277,12 @@ class Primaite(Env):
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
def _write_av_reward_per_episode(self):
def _write_av_reward_per_episode(self) -> None:
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
def reset(self):
def reset(self) -> np.ndarray:
"""
AI Gym Reset function.
@@ -299,7 +316,7 @@ class Primaite(Env):
return self.env_obs
def step(self, action):
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
"""
AI Gym Step function.
@@ -418,7 +435,7 @@ class Primaite(Env):
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
def close(self) -> None:
"""Override parent close and close writers."""
# Close files if last episode/step
# if self.can_finish:
@@ -427,18 +444,18 @@ class Primaite(Env):
self.transaction_writer.close()
self.episode_av_reward_writer.close()
def init_acl(self):
def init_acl(self) -> None:
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
def output_link_status(self):
def output_link_status(self) -> None:
"""Output the link status of all links to the console."""
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
def interpret_action_and_apply(self, _action: int) -> None:
"""
Applies agent actions to the nodes and Access Control List.
@@ -446,19 +463,18 @@ class Primaite(Env):
_action: The action space from the agent
"""
# At the moment, actions are only affecting nodes
if self.training_config.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
elif len(self.action_dict[_action]) == 7: # ACL actions in multidiscrete form have len 7
self.apply_actions_to_acl(_action)
elif len(self.action_dict[_action]) == 4: # Node actions in multidiscrete (array) form have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
def apply_actions_to_nodes(self, _action: int) -> None:
"""
Applies agent actions to the nodes.
@@ -546,7 +562,7 @@ class Primaite(Env):
else:
return
def apply_actions_to_acl(self, _action):
def apply_actions_to_acl(self, _action: int) -> None:
"""
Applies agent actions to the Access Control List [TO DO].
@@ -562,6 +578,7 @@ class Primaite(Env):
action_destination_ip = readable_action[3]
action_protocol = readable_action[4]
action_port = readable_action[5]
acl_rule_position = readable_action[6]
if action_decision == 0:
# It's decided to do nothing
@@ -611,6 +628,7 @@ class Primaite(Env):
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_rule_position,
)
elif action_decision == 2:
# Remove the rule
@@ -624,7 +642,7 @@ class Primaite(Env):
else:
return
def apply_time_based_updates(self):
def apply_time_based_updates(self) -> None:
"""
Updates anything that needs to count down and then change state.
@@ -680,12 +698,12 @@ class Primaite(Env):
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
def update_environent_obs(self) -> None:
"""Updates the observation space based on the node and link status."""
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
def load_lay_down_config(self):
def load_lay_down_config(self) -> None:
"""Loads config data in order to build the environment configuration."""
for item in self.lay_down_config:
if item["item_type"] == "NODE":
@@ -723,7 +741,7 @@ class Primaite(Env):
_LOGGER.info("Environment configuration loaded")
print("Environment configuration loaded")
def create_node(self, item):
def create_node(self, item: Dict) -> None:
"""
Creates a node from config data.
@@ -804,7 +822,7 @@ class Primaite(Env):
# Add node to network (reference)
self.network_reference.add_nodes_from([node_ref])
def create_link(self, item: Dict):
def create_link(self, item: Dict) -> None:
"""
Creates a link from config data.
@@ -848,7 +866,7 @@ class Primaite(Env):
self.services_list,
)
def create_green_ier(self, item):
def create_green_ier(self, item: Dict) -> None:
"""
Creates a green IER from config data.
@@ -889,7 +907,7 @@ class Primaite(Env):
ier_mission_criticality,
)
def create_red_ier(self, item):
def create_red_ier(self, item: Dict) -> None:
"""
Creates a red IER from config data.
@@ -919,7 +937,7 @@ class Primaite(Env):
ier_mission_criticality,
)
def create_green_pol(self, item):
def create_green_pol(self, item: Dict) -> None:
"""
Creates a green PoL object from config data.
@@ -953,7 +971,7 @@ class Primaite(Env):
pol_state,
)
def create_red_pol(self, item):
def create_red_pol(self, item: Dict) -> None:
"""
Creates a red PoL object from config data.
@@ -994,7 +1012,7 @@ class Primaite(Env):
pol_source_node_service_state,
)
def create_acl_rule(self, item):
def create_acl_rule(self, item: Dict) -> None:
"""
Creates an ACL rule from config data.
@@ -1006,6 +1024,7 @@ class Primaite(Env):
acl_rule_destination = item["destination"]
acl_rule_protocol = item["protocol"]
acl_rule_port = item["port"]
acl_rule_position = item["position"]
self.acl.add_rule(
acl_rule_permission,
@@ -1013,9 +1032,11 @@ class Primaite(Env):
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_rule_position,
)
def create_services_list(self, services):
# TODO: confirm typehint using runtime
def create_services_list(self, services: Dict) -> None:
"""
Creates a list of services (enum) from config data.
@@ -1031,7 +1052,7 @@ class Primaite(Env):
# Set the number of services
self.num_services = len(self.services_list)
def create_ports_list(self, ports):
def create_ports_list(self, ports: Dict) -> None:
"""
Creates a list of ports from config data.
@@ -1047,7 +1068,8 @@ class Primaite(Env):
# Set the number of ports
self.num_ports = len(self.ports_list)
def get_observation_info(self, observation_info):
# TODO: this is not used anymore, write a ticket to delete it
def get_observation_info(self, observation_info: Dict) -> None:
"""
Extracts observation_info.
@@ -1056,7 +1078,8 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
# TODO: this is not used anymore, write a ticket to delete it.
def get_action_info(self, action_info: Dict) -> None:
"""
Extracts action_info.
@@ -1065,7 +1088,7 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def save_obs_config(self, obs_config: dict):
def save_obs_config(self, obs_config: dict) -> None:
"""
Cache the config for the observation space.
@@ -1078,7 +1101,7 @@ class Primaite(Env):
"""
self.obs_config = obs_config
def reset_environment(self):
def reset_environment(self) -> None:
"""
Resets environment.
@@ -1103,7 +1126,7 @@ class Primaite(Env):
for ier_key, ier_value in self.red_iers.items():
ier_value.set_is_running(False)
def reset_node(self, item):
def reset_node(self, item: Dict) -> None:
"""
Resets the statuses of a node.
@@ -1151,7 +1174,7 @@ class Primaite(Env):
# Bad formatting
pass
def create_node_action_dict(self):
def create_node_action_dict(self) -> Dict[int, List[int]]:
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
@@ -1167,6 +1190,11 @@ class Primaite(Env):
...
}
"""
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
# reserve 0 action to be a nothing action
actions = {0: [1, 0, 0, 0]}
action_key = 1
@@ -1186,10 +1214,18 @@ class Primaite(Env):
return actions
def create_acl_action_dict(self):
def create_acl_action_dict(self) -> Dict[int, List[int]]:
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
@@ -1201,22 +1237,25 @@ class Primaite(Env):
for dest_ip in range(self.num_nodes + 1):
for protocol in range(self.num_services + 1):
for port in range(self.num_ports + 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
]
# Check to see if its an action we want to include as possible i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
for position in range(self.max_number_acl_rules - 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
position,
]
# Check to see if it is an action we want to include as possible
# i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
return actions
def create_node_and_acl_action_dict(self):
def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]:
"""
Create a dictionary mapping each possible discrete action to a more readable multidiscrete action.
@@ -1233,7 +1272,7 @@ class Primaite(Env):
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict
def _create_random_red_agent(self):
def _create_random_red_agent(self) -> None:
"""Decide on random red agent for the episode to be called in env.reset()."""
# Reset the current red iers and red node pol
self.red_iers = {}

View File

@@ -1,25 +1,31 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict
from logging import Logger
from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.config.training_config import TrainingConfig
from primaite.pol.ier import IER
_LOGGER: Logger = getLogger(__name__)
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
green_iers_reference,
red_iers,
step_count,
config_values,
initial_nodes: Dict[str, NodeUnion],
final_nodes: Dict[str, NodeUnion],
reference_nodes: Dict[str, NodeUnion],
green_iers: Dict[str, "IER"],
green_iers_reference: Dict[str, "IER"],
red_iers: Dict[str, "IER"],
step_count: int,
config_values: "TrainingConfig",
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
@@ -93,7 +99,9 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_operating_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the hardware state of a node.
@@ -142,7 +150,12 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_os_state(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the Software State of a node.
@@ -193,7 +206,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_service_state(
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the service state(s) of a node.
@@ -265,7 +280,12 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
def score_node_file_system(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the file system state of a node.
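The reward module above imports `TrainingConfig` and `IER` only under `TYPE_CHECKING` and quotes them in annotations. A minimal sketch of that pattern follows, using the standard library's `Decimal` as a stand-in for the PrimAITE classes; the function name `scale` is hypothetical.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from decimal import Decimal


def scale(value: "Decimal", factor: int) -> "Decimal":
    """Multiply a value by an integer factor."""
    # The quoted annotations are stored as plain strings, so the
    # missing runtime import of Decimal does not raise a NameError.
    return value * factor
```

Because the guarded import never runs, this approach avoids circular imports at runtime, at the cost of the names being unavailable for `isinstance` checks in that module.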

View File

@@ -8,7 +8,7 @@ from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
"""
Initialise a Link within the simulated network.
@@ -18,17 +18,17 @@ class Link(object):
:param _dest_node_name: The name of the destination node
:param _services: The protocols (services) to add to the link
"""
self.id = _id
self.bandwidth = _bandwidth
self.source_node_name = _source_node_name
self.dest_node_name = _dest_node_name
self.id: str = _id
self.bandwidth: int = _bandwidth
self.source_node_name: str = _source_node_name
self.dest_node_name: str = _dest_node_name
self.protocol_list: List[Protocol] = []
# Add the default protocols
for protocol_name in _services:
self.add_protocol(protocol_name)
def add_protocol(self, _protocol):
def add_protocol(self, _protocol: str) -> None:
"""
Adds a new protocol to the list of protocols on this link.
@@ -37,7 +37,7 @@ class Link(object):
"""
self.protocol_list.append(Protocol(_protocol))
def get_id(self):
def get_id(self) -> str:
"""
Gets link ID.
@@ -46,7 +46,7 @@ class Link(object):
"""
return self.id
def get_source_node_name(self):
def get_source_node_name(self) -> str:
"""
Gets source node name.
@@ -55,7 +55,7 @@ class Link(object):
"""
return self.source_node_name
def get_dest_node_name(self):
def get_dest_node_name(self) -> str:
"""
Gets destination node name.
@@ -64,7 +64,7 @@ class Link(object):
"""
return self.dest_node_name
def get_bandwidth(self):
def get_bandwidth(self) -> int:
"""
Gets bandwidth of link.
@@ -73,7 +73,7 @@ class Link(object):
"""
return self.bandwidth
def get_protocol_list(self):
def get_protocol_list(self) -> List[Protocol]:
"""
Gets list of protocols on this link.
@@ -82,7 +82,7 @@ class Link(object):
"""
return self.protocol_list
def get_current_load(self):
def get_current_load(self) -> int:
"""
Gets current total load on this link.
@@ -94,7 +94,7 @@ class Link(object):
total_load += protocol.get_load()
return total_load
def add_protocol_load(self, _protocol, _load):
def add_protocol_load(self, _protocol: str, _load: int) -> None:
"""
Adds a loading to a protocol on this link.
@@ -108,7 +108,7 @@ class Link(object):
else:
pass
def clear_traffic(self):
def clear_traffic(self) -> None:
"""Clears all traffic on this link."""
for protocol in self.protocol_list:
protocol.clear_load()

View File

@@ -14,7 +14,7 @@ def run(
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Run the PrimAITE Session.

View File

@@ -24,7 +24,7 @@ class ActiveNode(Node):
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise an active node.
@@ -60,7 +60,7 @@ class ActiveNode(Node):
return self._software_state
@software_state.setter
def software_state(self, software_state: SoftwareState):
def software_state(self, software_state: SoftwareState) -> None:
"""
Set the software_state.
@@ -79,7 +79,7 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
"""
Sets Software State if the node is not compromised.
@@ -99,14 +99,14 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def update_os_patching_status(self):
def update_os_patching_status(self) -> None:
"""Updates operating system status based on patching cycle."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self._software_state = SoftwareState.GOOD
def set_file_system_state(self, file_system_state: FileSystemState):
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed).
@@ -133,7 +133,7 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed) if not in a compromised state.
@@ -166,12 +166,12 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def start_file_system_scan(self):
def start_file_system_scan(self) -> None:
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self):
def update_file_system_state(self) -> None:
"""Updates file system status based on scanning/restore/repair cycle."""
# Decrement both the action count (for restoring or repairing) and the scanning count
self.file_system_action_count -= 1
@@ -193,14 +193,14 @@ class ActiveNode(Node):
self.file_system_scanning = False
self.file_system_scanning_count = 0
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Updates the reset count and sets the software and file system state to GOOD."""
super().update_resetting_status()
if self.resetting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Updates the booting count and sets the software and file system state to GOOD."""
super().update_booting_status()
if self.booting_count <= 0:

View File

@@ -17,7 +17,7 @@ class Node:
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a node.
@@ -38,40 +38,40 @@ class Node:
self.booting_count: int = 0
self.shutting_down_count: int = 0
def __repr__(self):
def __repr__(self) -> str:
"""Returns the name of the node."""
return self.name
def turn_on(self):
def turn_on(self) -> None:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
def turn_off(self) -> None:
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
self.shutting_down_count = self.config_values.node_shutdown_duration
def reset(self):
def reset(self) -> None:
"""Sets the node state to Resetting and starts the reset count."""
self.hardware_state = HardwareState.RESETTING
self.resetting_count = self.config_values.node_reset_duration
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Updates the resetting count."""
self.resetting_count -= 1
if self.resetting_count <= 0:
self.resetting_count = 0
self.hardware_state = HardwareState.ON
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
def update_shutdown_status(self) -> None:
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
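The countdown pattern that `Node` uses for booting, resetting and shutting down can be sketched on its own. This is a simplified illustration under stated assumptions: `CountdownNode` and the trimmed `HardwareState` enum are hypothetical stand-ins, and the duration is passed directly rather than read from a `TrainingConfig`.

```python
from enum import Enum


class HardwareState(Enum):
    """Trimmed-down stand-in for the PrimAITE hardware state enum."""

    ON = 1
    OFF = 2
    BOOTING = 3


class CountdownNode:
    """Node whose boot completes after a fixed number of time steps."""

    def __init__(self, booting_duration: int) -> None:
        self.booting_duration: int = booting_duration
        self.hardware_state: HardwareState = HardwareState.OFF
        self.booting_count: int = 0

    def turn_on(self) -> None:
        """Enter BOOTING and start the countdown."""
        self.hardware_state = HardwareState.BOOTING
        self.booting_count = self.booting_duration

    def update_booting_status(self) -> None:
        """Decrement the countdown; switch to ON once it reaches zero."""
        self.booting_count -= 1
        if self.booting_count <= 0:
            self.booting_count = 0
            self.hardware_state = HardwareState.ON
```

Calling `update_booting_status()` once per environment step gives the delayed transition: with a duration of 2, the node stays in BOOTING after the first update and becomes ON after the second.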

View File

@@ -1,5 +1,9 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
class NodeStateInstructionGreen(object):
@@ -7,14 +11,14 @@ class NodeStateInstructionGreen(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_node_id,
_node_pol_type,
_service_name,
_state,
):
_id: str,
_start_step: int,
_end_step: int,
_node_id: str,
_node_pol_type: "NodePOLType",
_service_name: str,
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
) -> None:
"""
Initialise the Node State Instruction.
@@ -30,11 +34,12 @@ class NodeStateInstructionGreen(object):
self.start_step = _start_step
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type = _node_pol_type
self.service_name = _service_name # Not used when not a service instruction
self.state = _state
self.node_pol_type: "NodePOLType" = _node_pol_type
self.service_name: str = _service_name # Not used when not a service instruction
# TODO: confirm type of state
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets the start step.
@@ -43,7 +48,7 @@ class NodeStateInstructionGreen(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets the end step.
@@ -52,7 +57,7 @@ class NodeStateInstructionGreen(object):
"""
return self.end_step
def get_node_id(self):
def get_node_id(self) -> str:
"""
Gets the node ID.
@@ -61,7 +66,7 @@ class NodeStateInstructionGreen(object):
"""
return self.node_id
def get_node_pol_type(self):
def get_node_pol_type(self) -> "NodePOLType":
"""
Gets the node pattern of life type (enum).
@@ -70,7 +75,7 @@ class NodeStateInstructionGreen(object):
"""
return self.node_pol_type
def get_service_name(self):
def get_service_name(self) -> str:
"""
Gets the service name.
@@ -79,7 +84,7 @@ class NodeStateInstructionGreen(object):
"""
return self.service_name
def get_state(self):
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).

View File

@@ -1,9 +1,13 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Red PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from primaite.common.enums import NodePOLType
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
@dataclass()
class NodeStateInstructionRed(object):
@@ -11,18 +15,18 @@ class NodeStateInstructionRed(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_target_node_id,
_pol_initiator,
_id: str,
_start_step: int,
_end_step: int,
_target_node_id: str,
_pol_initiator: "NodePOLInitiator",
_pol_type: NodePOLType,
pol_protocol,
_pol_state,
_pol_source_node_id,
_pol_source_node_service,
_pol_source_node_service_state,
):
pol_protocol: str,
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
_pol_source_node_id: str,
_pol_source_node_service: str,
_pol_source_node_service_state: str,
) -> None:
"""
Initialise the Node State Instruction for the red agent.
@@ -38,19 +42,19 @@ class NodeStateInstructionRed(object):
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.target_node_id = _target_node_id
self.initiator = _pol_initiator
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.target_node_id: str = _target_node_id
self.initiator: "NodePOLInitiator" = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name = pol_protocol # Not used when not a service instruction
self.state = _pol_state
self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service
self.service_name: str = pol_protocol # Not used when not a service instruction
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
self.source_node_id: str = _pol_source_node_id
self.source_node_service: str = _pol_source_node_service
self.source_node_service_state = _pol_source_node_service_state
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets the start step.
@@ -59,7 +63,7 @@ class NodeStateInstructionRed(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets the end step.
@@ -68,7 +72,7 @@ class NodeStateInstructionRed(object):
"""
return self.end_step
def get_target_node_id(self):
def get_target_node_id(self) -> str:
"""
Gets the node ID.
@@ -77,7 +81,7 @@ class NodeStateInstructionRed(object):
"""
return self.target_node_id
def get_initiator(self):
def get_initiator(self) -> "NodePOLInitiator":
"""
Gets the initiator.
@@ -95,7 +99,7 @@ class NodeStateInstructionRed(object):
"""
return self.pol_type
def get_service_name(self):
def get_service_name(self) -> str:
"""
Gets the service name.
@@ -104,7 +108,7 @@ class NodeStateInstructionRed(object):
"""
return self.service_name
def get_state(self):
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
@@ -113,7 +117,7 @@ class NodeStateInstructionRed(object):
"""
return self.state
def get_source_node_id(self):
def get_source_node_id(self) -> str:
"""
Gets the source node id (used for initiator type SERVICE).
@@ -122,7 +126,7 @@ class NodeStateInstructionRed(object):
"""
return self.source_node_id
def get_source_node_service(self):
def get_source_node_service(self) -> str:
"""
Gets the source node service (used for initiator type SERVICE).
@@ -131,7 +135,7 @@ class NodeStateInstructionRed(object):
"""
return self.source_node_service
def get_source_node_service_state(self):
def get_source_node_service_state(self) -> str:
"""
Gets the source node service state (used for initiator type SERVICE).

View File

@@ -16,7 +16,7 @@ class PassiveNode(Node):
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a passive node.

View File

@@ -25,7 +25,7 @@ class ServiceNode(ActiveNode):
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a Service Node.
@@ -52,7 +52,7 @@ class ServiceNode(ActiveNode):
)
self.services: Dict[str, Service] = {}
def add_service(self, service: Service):
def add_service(self, service: Service) -> None:
"""
Adds a service to the node.
@@ -102,7 +102,7 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
@@ -131,7 +131,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
@@ -158,7 +158,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def get_service_state(self, protocol_name):
def get_service_state(self, protocol_name: str) -> SoftwareState:
"""
Gets the state of a service.
@@ -169,20 +169,20 @@ class ServiceNode(ActiveNode):
if service_value:
return service_value.software_state
def update_services_patching_status(self):
def update_services_patching_status(self) -> None:
"""Updates the patching counter for any services that are patching."""
for service_key, service_value in self.services.items():
if service_value.software_state == SoftwareState.PATCHING:
service_value.reduce_patching_count()
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:

View File

@@ -5,13 +5,14 @@ import importlib.util
import os
import subprocess
import sys
from logging import Logger
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def start_jupyter_session():
def start_jupyter_session() -> None:
"""
Starts a new Jupyter notebook session in the app notebooks directory.

View File

@@ -1,6 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict, Union
from typing import Dict
from networkx import MultiGraph, shortest_path
@@ -10,11 +10,10 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE = False
_VERBOSE: bool = False
def apply_iers(
@@ -24,7 +23,7 @@ def apply_iers(
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
):
) -> None:
"""
Applies IERs to the links (link pattern of life).
@@ -65,6 +64,8 @@ def apply_iers(
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if (
@@ -215,9 +216,9 @@ def apply_iers(
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
node_pol: Dict[str, NodeStateInstructionGreen],
step: int,
):
) -> None:
"""
Applies node pattern of life.

View File

@@ -11,17 +11,17 @@ class IER(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_load,
_protocol,
_port,
_source_node_id,
_dest_node_id,
_mission_criticality,
_running=False,
):
_id: str,
_start_step: int,
_end_step: int,
_load: int,
_protocol: str,
_port: str,
_source_node_id: str,
_dest_node_id: str,
_mission_criticality: int,
_running: bool = False,
) -> None:
"""
Initialise an Information Exchange Request.
@@ -36,18 +36,18 @@ class IER(object):
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.source_node_id = _source_node_id
self.dest_node_id = _dest_node_id
self.load = _load
self.protocol = _protocol
self.port = _port
self.mission_criticality = _mission_criticality
self.running = _running
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.source_node_id: str = _source_node_id
self.dest_node_id: str = _dest_node_id
self.load: int = _load
self.protocol: str = _protocol
self.port: str = _port
self.mission_criticality: int = _mission_criticality
self.running: bool = _running
def get_id(self):
def get_id(self) -> str:
"""
Gets IER ID.
@@ -56,7 +56,7 @@ class IER(object):
"""
return self.id
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets IER start step.
@@ -65,7 +65,7 @@ class IER(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets IER end step.
@@ -74,7 +74,7 @@ class IER(object):
"""
return self.end_step
def get_load(self):
def get_load(self) -> int:
"""
Gets IER load.
@@ -83,7 +83,7 @@ class IER(object):
"""
return self.load
def get_protocol(self):
def get_protocol(self) -> str:
"""
Gets IER protocol.
@@ -92,7 +92,7 @@ class IER(object):
"""
return self.protocol
def get_port(self):
def get_port(self) -> str:
"""
Gets IER port.
@@ -101,7 +101,7 @@ class IER(object):
"""
return self.port
def get_source_node_id(self):
def get_source_node_id(self) -> str:
"""
Gets IER source node ID.
@@ -110,7 +110,7 @@ class IER(object):
"""
return self.source_node_id
def get_dest_node_id(self):
def get_dest_node_id(self) -> str:
"""
Gets IER destination node ID.
@@ -119,7 +119,7 @@ class IER(object):
"""
return self.dest_node_id
def get_is_running(self):
def get_is_running(self) -> bool:
"""
Informs whether the IER is currently running.
@@ -128,7 +128,7 @@ class IER(object):
"""
return self.running
def set_is_running(self, _value):
def set_is_running(self, _value: bool) -> None:
"""
Sets the running state of the IER.
@@ -137,7 +137,7 @@ class IER(object):
"""
self.running = _value
def get_mission_criticality(self):
def get_mission_criticality(self) -> int:
"""
Gets the IER mission criticality (used in the reward function).
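
For orientation, the annotated IER fields in this hunk can be sketched as a small dataclass. This is an illustration only: `IERSketch` and its `in_window` helper are hypothetical names and not part of PrimAITE, and the sample values are borrowed from the GREEN_IER entry in the ACL laydown added elsewhere in this commit.

```python
from dataclasses import dataclass


@dataclass
class IERSketch:
    """Illustrative stand-in for the annotated IER class (not the PrimAITE class)."""

    id: str
    start_step: int
    end_step: int
    load: int
    protocol: str
    port: str
    source_node_id: str
    dest_node_id: str
    mission_criticality: int
    running: bool = False

    def in_window(self, step: int) -> bool:
        """Hypothetical helper: True when start_step <= step <= end_step."""
        return self.start_step <= step <= self.end_step


# Values mirror the GREEN_IER laydown entry (id '5', steps 0 to 5, TCP on port '80').
ier = IERSketch("5", 0, 5, 999, "TCP", "80", "1", "2", 5)
```

Keeping `port` and the node IDs as `str` matches the annotations in the hunk, where the laydown quotes these values.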

View File

@@ -4,6 +4,7 @@ from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
@@ -13,7 +14,9 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE = False
_LOGGER = getLogger(__name__)
_VERBOSE: bool = False
def apply_red_agent_iers(
@@ -23,7 +26,7 @@ def apply_red_agent_iers(
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
):
) -> None:
"""
Applies IERs to the links (link POL) resulting from red agent attack.
@@ -74,6 +77,9 @@ def apply_red_agent_iers(
pass
else:
# It's not a switch or an actuator (so active node)
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
# to change according to duck typing.
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state
@@ -213,7 +219,7 @@ def apply_red_agent_node_pol(
iers: Dict[str, IER],
node_pol: Dict[str, NodeStateInstructionRed],
step: int,
):
) -> None:
"""
Applies node pattern of life.
@@ -267,8 +273,7 @@ def apply_red_agent_node_pol(
# Do nothing, service not on this node
pass
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - misconfiguration")
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
# Only apply the PoL if the checks have passed (based on the initiator type)
if passed_checks:
@@ -289,8 +294,7 @@ def apply_red_agent_node_pol(
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - did not pass checks")
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Optional, Union
from typing import Any, Dict, Final, Optional, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
@@ -32,7 +32,7 @@ class PrimaiteSession:
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
The PrimaiteSession constructor.
@@ -72,7 +72,13 @@ class PrimaiteSession:
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
def setup(self):
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self) -> None:
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
@@ -155,8 +161,8 @@ class PrimaiteSession:
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -167,8 +173,8 @@ class PrimaiteSession:
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -177,6 +183,6 @@ class PrimaiteSession:
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(**kwargs)
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._agent_session.close()

View File

@@ -1,10 +1,11 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from primaite import getLogger
_LOGGER = getLogger(__name__)
def run():
def run() -> None:
"""Perform the full clean-up."""
pass

View File

@@ -2,16 +2,17 @@
import filecmp
import os
import shutil
from logging import Logger
from pathlib import Path
import pkg_resources
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run(overwrite_existing: bool = True):
def run(overwrite_existing: bool = True) -> None:
"""
Resets the demo jupyter notebooks in the users app notebooks directory.

View File

@@ -11,7 +11,7 @@ from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
def run(overwrite_existing=True):
def run(overwrite_existing: bool = True) -> None:
"""
Resets the example config files in the users app config directory.

View File

@@ -1,10 +1,12 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from logging import Logger
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run():
def run() -> None:
"""
Handles creation of application directories and user directories.

View File

@@ -1,15 +1,19 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Tuple
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from primaite.common.enums import AgentIdentifier
if TYPE_CHECKING:
import numpy as np
from gym import spaces
class Transaction(object):
"""Transaction class."""
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
"""
Transaction constructor.
@@ -17,7 +21,7 @@ class Transaction(object):
:param episode_number: The episode number
:param step_number: The step number
"""
self.timestamp = datetime.now()
self.timestamp: datetime = datetime.now()
"The datetime of the transaction"
self.agent_identifier: AgentIdentifier = agent_identifier
"The agent identifier"
@@ -25,17 +29,17 @@ class Transaction(object):
"The episode number"
self.step_number: int = step_number
"The step number"
self.obs_space = None
self.obs_space: "spaces.Space" = None
"The observation space (pre)"
self.obs_space_pre = None
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space before any actions are taken"
self.obs_space_post = None
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space after any actions are taken"
self.reward: float = None
self.reward: Optional[float] = None
"The reward value"
self.action_space = None
self.action_space: Optional[int] = None
"The action space invoked by the agent"
self.obs_space_description = None
self.obs_space_description: Optional[List[str]] = None
"The env observation space description"
def as_csv_data(self) -> Tuple[List, List]:
@@ -68,7 +72,7 @@ class Transaction(object):
return header, row
def _turn_action_space_to_array(action_space) -> List[str]:
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
"""
Turns action space into a string array so it can be saved to csv.
@@ -81,7 +85,7 @@ def _turn_action_space_to_array(action_space) -> List[str]:
return [str(action_space)]
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
"""
Turns observation space into a string array so it can be saved to csv.

View File

@@ -1,12 +1,13 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
from logging import Logger
from pathlib import Path
import pkg_resources
from primaite import getLogger
_LOGGER = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def get_file_path(path: str) -> Path:

View File

@@ -1,7 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import json
from pathlib import Path
from typing import Union
from typing import Any, Dict, Union
import yaml
@@ -10,7 +10,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]:
"""
Loads a session metadata from the given directory path.

View File

@@ -7,6 +7,9 @@ from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from io import TextIOWrapper
from pathlib import Path
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
@@ -29,7 +32,7 @@ class SessionOutputWriter:
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True,
):
) -> None:
"""
Initialise the Session Output Writer.
@@ -42,15 +45,16 @@ class SessionOutputWriter:
determines the name of the folder which contains the final output csv. Defaults to True
:type learning_session: bool, optional
"""
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
self._env: "Primaite" = env
self.transaction_writer: bool = transaction_writer
self.learning_session: bool = learning_session
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
self._csv_file_path: "Path"
if self.learning_session:
self._csv_file_path = self._env.session_path / "learning" / fn
else:
@@ -58,26 +62,26 @@ class SessionOutputWriter:
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
self._csv_file = None
self._csv_writer = None
self._csv_file: "TextIOWrapper" = None
self._csv_writer: "csv._writer" = None
self._first_write: bool = True
def _init_csv_writer(self):
def _init_csv_writer(self) -> None:
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
def __del__(self) -> None:
self.close()
def close(self):
def close(self) -> None:
"""Close the csv file."""
if self._csv_file:
self._csv_file.close()
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(self, data: Union[Tuple, Transaction]):
def write(self, data: Union[Tuple, Transaction]) -> None:
"""
Write a row of session data.

File diff suppressed because one or more lines are too long

View File

@@ -92,6 +92,7 @@
destination: 192.168.1.2
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '7'
permission: ALLOW
@@ -99,3 +100,4 @@
destination: 192.168.1.1
protocol: TCP
port: 80
position: 0

View File

@@ -0,0 +1,86 @@
- item_type: PORTS
ports_list:
- port: '80'
- port: '21'
- item_type: SERVICES
service_list:
- name: TCP
- name: FTP
########################################
# Nodes
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: COMPROMISED
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: FTP
port: '21'
state: GOOD
- item_type: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: FTP
port: '21'
state: OVERWHELMED
- item_type: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- item_type: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- item_type: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'
#########################################
# IERS
- item_type: GREEN_IER
id: '5'
start_step: 0
end_step: 5
load: 999
protocol: TCP
port: '80'
source: '1'
destination: '2'
mission_criticality: 5
#########################################
# ACL Rules

View File

@@ -0,0 +1,106 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_framework: SB3
agent_identifier: PPO
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 5
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 3
observation_space:
components:
- name: ACCESS_CONTROL_LIST
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1_000_000_000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -39,6 +39,11 @@ observation_space:
# Time delay between steps (for generic agents)
time_delay: 1
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: ALLOW
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 4
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAIN
# Determine whether to load an agent from file

View File

@@ -37,6 +37,11 @@ observation_space:
time_delay: 1
# Filename of the scenario / laydown
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: ALLOW
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 4
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False

View File

@@ -40,7 +40,8 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1_000_000_000
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Reward values
# Generic
all_ok: 0

View File

@@ -91,6 +91,8 @@ session_type: EVAL
# The high value for the observation space
observation_space_high_value: 1000000000
implicit_acl_rule: DENY
max_number_acl_rules: 10
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)

View File

@@ -36,7 +36,7 @@ random_red_agent: False
# Default is None (null)
seed: None
# Set whether the agent will be deterministic instead of stochastic
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
@@ -55,11 +55,11 @@ hard_coded_agent_view: FULL
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# - name: ACCESS_CONTROL_LIST
# Number of episodes to run per session
num_train_episodes: 10

View File

@@ -36,7 +36,7 @@ random_red_agent: False
# Default is None (null)
seed: 67890
# Set whether the agent will be deterministic instead of stochastic
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
@@ -55,7 +55,6 @@ hard_coded_agent_view: FULL
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES

View File

@@ -38,6 +38,15 @@ load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 10
observation_space:
components:
- name: ACCESS_CONTROL_LIST
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000

View File

@@ -1,10 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'
- port: '80'
- item_type: SERVICES
service_list:
- name: ftp
- name: TCP
- item_type: NODE
node_id: '1'
name: node
@@ -16,8 +16,8 @@
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
- name: TCP
port: '80'
state: COMPROMISED
- item_type: NODE
node_id: '2'
@@ -30,15 +30,15 @@
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
- name: TCP
port: '80'
state: COMPROMISED
- item_type: RED_IER
id: '3'
start_step: 2
end_step: 15
load: 1000
protocol: ftp
protocol: TCP
port: CORRUPT
source: '1'
destination: '2'

View File

@@ -47,6 +47,9 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# The high value for the observation space
observation_space_high_value: 1000000000
# Implicit ACL rule applied when no explicit rule matches (ALLOW or DENY)
implicit_acl_rule: DENY
max_number_acl_rules: 10
# Reward values
# Generic
all_ok: 0

View File

@@ -3,40 +3,41 @@
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
def test_acl_address_match_1():
"""Test that matching IP addresses produce True."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_2():
"""Test that mismatching IP addresses produce False."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
def test_acl_address_match_3():
"""Test the ANY condition for source IP addresses produce True."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_4():
"""Test the ANY condition for dest IP addresses produce True."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
@@ -44,14 +45,15 @@ def test_acl_address_match_4():
def test_check_acl_block_affirmative():
"""Test the block function (affirmative)."""
# Create the Access Control List
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
# Create a rule
acl_rule_permission = "ALLOW"
acl_rule_permission = RulePermissionType.ALLOW
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
acl_rule_port = "80"
acl_position_in_list = "0"
acl.add_rule(
acl_rule_permission,
@@ -59,22 +61,23 @@ def test_check_acl_block_affirmative():
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
def test_check_acl_block_negative():
"""Test the block function (negative)."""
# Create the Access Control List
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
# Create a rule
acl_rule_permission = "DENY"
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
acl_rule_port = "80"
acl_position_in_list = "0"
acl.add_rule(
acl_rule_permission,
@@ -82,6 +85,7 @@ def test_check_acl_block_negative():
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
@@ -90,11 +94,73 @@ def test_check_acl_block_negative():
def test_rule_hash():
"""Test the rule hash."""
# Create the Access Control List
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert hash_value_local == hash_value_remote
def test_delete_rule():
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
# Create the Access Control List
acl = AccessControlList(RulePermissionType.ALLOW, 10)
# Create a first rule
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
acl_rule_port = "80"
acl_position_in_list = "0"
acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
# Create a second rule
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "20"
acl_rule_destination = "30"
acl_rule_protocol = "FTP"
acl_rule_port = "21"
acl_position_in_list = "2"
acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
# Create a third rule
acl_rule_permission = RulePermissionType.ALLOW
acl_rule_source = "192.168.1.3"
acl_rule_destination = "192.168.1.1"
acl_rule_protocol = "UDP"
acl_rule_port = "60"
acl_position_in_list = "4"
acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
# Remove the second ACL rule added from the list
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
assert len(acl.acl) == 10
assert acl.acl[2] is None

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Test env creation and behaviour with different observation spaces."""
import numpy as np
import pytest
@@ -237,3 +238,140 @@ class TestLinkTrafficLevels:
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
]
],
indirect=True,
)
class TestAccessControlList:
"""Test the AccessControlList observation component (in isolation)."""
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with MultiDiscrete observation space.
The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have.
Each ACL Rule in the observation space has 6 different elements:
6 * 3 = 18
"""
with temp_primaite_session as session:
env = session.env
env.update_environent_obs()
assert env.env_obs.shape == (18,)
def test_values(self, temp_primaite_session):
"""Test that ACL values are encoded correctly.
The laydown has:
* one ACL IMPLICIT DENY rule
Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at
Position 2.
"""
with temp_primaite_session as session:
env = session.env
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
def test_observation_space_with_implicit_rule(self, temp_primaite_session):
"""
Test observation space is what is expected when an agent adds ACLs during an episode.
At the start of the episode, there is a single implicit DENY rule
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0
On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1
The final observation space should be this:
[2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
"""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
training_config = env.training_config
for episode in range(0, training_config.num_train_episodes):
for step in range(0, training_config.num_train_steps):
# Do nothing action
action = 0
if step == 2:
# Action to add the first ACL rule
action = 43
elif step == 4:
# Action to add the second ACL rule
action = 96
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
obs = env.env_obs
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
def test_observation_space_with_different_positions(self, temp_primaite_session):
"""
Test observation space is what is expected when an agent adds ACLs during an episode.
At the start of the episode, there is a single implicit DENY rule.
In the observation space, IMPLICIT DENY is encoded as: 1,1,1,1,1,0
The trailing 0 indicates that the rule is at the start of the list (no other rules had been created when the
episode began), so this is correct.
On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1
On Step 4, there is a second ACL rule added at Position 0: 2,4,2,3,3,0
The final observation space should be this:
[2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2]
The ACL rule from Step 2 is added first but sits at a later position (1) than the ACL rule from Step 4 (0),
but both come before the IMPLICIT DENY, which will ALWAYS be at the end of the ACL list.
"""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
training_config = env.training_config
for episode in range(0, training_config.num_train_episodes):
for step in range(0, training_config.num_train_steps):
# Do nothing action
action = 0
if step == 2:
# Action to add the first ACL rule
action = 44
elif step == 4:
# Action to add the second ACL rule
action = 95
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
obs = env.env_obs
assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])
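
The layout described in the docstrings above can be illustrated with a small standalone sketch. This is not PrimAITE's implementation: `flatten_acl_observation` and `ELEMENTS_PER_RULE` are hypothetical names, and the reading that each rule is five encoded fields followed by its position is an assumption drawn from the expected arrays in these tests.

```python
from typing import Dict, List, Sequence

ELEMENTS_PER_RULE = 6  # assumed: five encoded rule fields plus the rule's position


def flatten_acl_observation(
    explicit_rules: Dict[int, Sequence[int]],
    implicit_rule: Sequence[int],
    max_rules: int,
) -> List[int]:
    """Lay out ACL rules as one flat vector of max_rules * 6 integers."""
    # Every slot starts empty, i.e. six zeros.
    slots = [[0] * ELEMENTS_PER_RULE for _ in range(max_rules)]
    # Explicit rules go into the slot matching their position.
    for position, fields in explicit_rules.items():
        if not 0 <= position < max_rules - 1:
            raise ValueError(f"position {position} is reserved or out of range")
        slots[position] = [*fields, position]
    # The implicit rule always takes the final slot.
    slots[max_rules - 1] = [*implicit_rule, max_rules - 1]
    return [value for slot in slots for value in slot]


implicit_deny = [1, 1, 1, 1, 1]
empty_acl = flatten_acl_observation({}, implicit_deny, max_rules=3)
two_rules = flatten_acl_observation(
    {1: [2, 2, 3, 2, 3], 0: [2, 4, 2, 3, 3]}, implicit_deny, max_rules=3
)
```

With three slots, `empty_acl` matches the array expected by `test_values` and `two_rules` matches the one expected by `test_observation_space_with_different_positions`, which is why the order in which rules are added does not affect the final vector.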


@@ -11,30 +11,46 @@ from tests import TEST_CONFIG_ROOT
indirect=True,
)
def test_seeded_learning(temp_primaite_session):
"""Test running seeded learning produces the same output when ran twice."""
"""
Test that running seeded learning produces the same output when run twice.
.. note::
If this is failing, the hard-coded expected_mean_reward_per_episode
from a pre-trained agent will probably need to be updated. If the
env has changed in a way that affects how this agent is trained,
chances are the mean rewards are going to be different.
Run the test, but print out the session.learn_av_reward_per_episode_dict()
before comparing it. Then copy the printed dict and replace the
expected_mean_reward_per_episode with those values. The test should
now work. If not, then you've got a bug :).
"""
expected_mean_reward_per_episode = {
1: -90.703125,
2: -91.15234375,
3: -87.5,
4: -92.2265625,
5: -94.6875,
6: -91.19140625,
7: -88.984375,
8: -88.3203125,
9: -112.79296875,
10: -100.01953125,
1: -20.7421875,
2: -19.82421875,
3: -17.01171875,
4: -19.08203125,
5: -21.93359375,
6: -20.21484375,
7: -15.546875,
8: -12.08984375,
9: -17.59765625,
10: -14.6875,
}
with temp_primaite_session as session:
assert session._training_config.seed == 67890, (
"Expected output is based upon a agent that was trained with " "seed 67890"
)
assert (
session._training_config.seed == 67890
), "Expected output is based upon a agent that was trained with seed 67890"
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
print(actual_mean_reward_per_episode, "THISt")
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
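
When the comparison above fails, it can help to see only the episodes that differ before copying new expected values across. The helper below is a hypothetical sketch and not part of the PrimAITE test suite; `reward_mismatches` is an illustrative name and the `actual` values are made up.

```python
from typing import Dict, Optional, Tuple


def reward_mismatches(
    expected: Dict[int, float], actual: Dict[int, float]
) -> Dict[int, Tuple[Optional[float], Optional[float]]]:
    """Return {episode: (expected, actual)} for every episode that differs."""
    episodes = sorted(set(expected) | set(actual))
    return {
        episode: (expected.get(episode), actual.get(episode))
        for episode in episodes
        # Exact comparison, mirroring the equality check in the test above.
        if expected.get(episode) != actual.get(episode)
    }


expected = {1: -20.7421875, 2: -19.82421875}
actual = {1: -20.7421875, 2: -19.5}  # illustrative values only
mismatches = reward_mismatches(expected, actual)
```

An empty result means the two runs agree exactly; a missing episode shows up with `None` on the side that lacks it.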
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.")
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],


@@ -6,6 +6,8 @@ from pathlib import Path
from typing import Union
from uuid import uuid4
import pytest
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier
@@ -41,6 +43,9 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
return copy_path
@pytest.mark.xfail(
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
)
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
expected_learn_mean_reward_per_episode = {
@@ -97,6 +102,7 @@ def test_load_sb3_session():
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
expected_learn_mean_reward_per_episode = {
@@ -157,6 +163,7 @@ def test_load_primaite_session():
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_run_loading():
"""Test loading session via main.run."""
expected_learn_mean_reward_per_episode = {


@@ -3,6 +3,7 @@ import time
import pytest
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import HardwareState
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
@@ -19,16 +20,17 @@ def run_generic_set_actions(env: Primaite):
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
# print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1]
# [1, 1, 2, 1, 1, 1, 1(position)]
# Creates an ACL rule
# Allows traffic from server_1 to node_1 on port FTP
action = 7
action = 56
elif step == 7:
# [1, 1, 2, 0] Node Action
# Sets Node 1 Hardware State to OFF
# Does not resolve any service
action = 16
action = 128
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
@@ -57,6 +59,10 @@ def run_generic_set_actions(env: Primaite):
)
def test_single_action_space_is_valid(temp_primaite_session):
"""Test single action space is valid."""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
@@ -73,7 +79,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
if len(dict_item) == 4:
contains_node_actions = True
# ACL action detected
elif len(dict_item) == 6:
elif len(dict_item) == 7:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
@@ -94,6 +100,10 @@ def test_single_action_space_is_valid(temp_primaite_session):
)
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
# Run environment with specified fixed blue agent actions only
@@ -105,11 +115,15 @@ def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
access_control_list = env.acl
# Use the Access Control List object acl object attribute to get dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
acl_rules_list = access_control_list.acl
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# This 1 rule added to the implicit deny means there should be 2 rules in total.
rules_count = 0
for rule in acl_rules_list:
if isinstance(rule, ACLRule):
rules_count += 1
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1
assert rules_count == 2
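
The switch from `len(...)` to an `isinstance` count matters if the ACL container holds fixed slots, some of which are empty. The sketch below assumes that empty slots are represented by a non-rule placeholder such as `None`; `StandInRule` and `count_rules` are hypothetical stand-ins and do not reflect the real `ACLRule` class.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StandInRule:
    """Hypothetical stand-in for an ACL rule; not the real ACLRule class."""

    permission: str
    position: int


def count_rules(slots: List[Optional[StandInRule]]) -> int:
    """Count only the slots that actually hold a rule."""
    return sum(isinstance(slot, StandInRule) for slot in slots)


# Three slots: one rule added by the agent, one empty slot, one implicit deny.
acl_slots = [StandInRule("ALLOW", 0), None, StandInRule("DENY", 2)]
```

Here `len(acl_slots)` is 3 while `count_rules(acl_slots)` is 2, which mirrors the test's expectation of one agent-created rule plus the implicit deny.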