Merge branch 'dev' into feature/1632_Add_benchmarking_scripts
This commit is contained in:
@@ -66,11 +66,11 @@ The environment config file consists of the following attributes:
|
||||
.. code-block:: yaml
|
||||
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
combine_service_traffic : False
|
||||
quantisation_levels: 99
|
||||
@@ -80,6 +80,7 @@ The environment config file consists of the following attributes:
|
||||
|
||||
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
|
||||
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
|
||||
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
|
||||
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
|
||||
|
||||
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
|
||||
@@ -128,6 +129,14 @@ The environment config file consists of the following attributes:
|
||||
|
||||
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
|
||||
|
||||
* **implicit_acl_rule** [str]
|
||||
|
||||
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
|
||||
|
||||
* **max_number_acl_rules** [int]
|
||||
|
||||
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
|
||||
|
||||
**Reward-Based Config Values**
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
@@ -477,3 +486,4 @@ The lay down config file consists of the following attributes:
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
|
||||
|
||||
@@ -53,3 +53,5 @@ v1.2 to v2.0 Migration guide
|
||||
* hard coded agent view
|
||||
|
||||
Each of these items have default values which are designed so that PrimAITE has the same behaviour as it did in 1.2.0, so you do not have to specify them.
|
||||
|
||||
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.
|
||||
|
||||
@@ -47,6 +47,105 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
|
||||
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
|
||||
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
||||
|
||||
|
||||
Outputs
|
||||
-------
|
||||
|
||||
PrimAITE produces four types of outputs:
|
||||
|
||||
* Session Metadata
|
||||
* Results
|
||||
* Diagrams
|
||||
* Saved agents (training checkpoints and a final trained agent)
|
||||
|
||||
|
||||
**Session Metadata**
|
||||
|
||||
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
|
||||
|
||||
* **uuid** - The UUID assigned to the session upon instantiation.
|
||||
* **start_datetime** - The date & time the session started in iso format.
|
||||
* **end_datetime** - The date & time the session ended in iso format.
|
||||
* **learning**
|
||||
* **total_episodes** - The total number of training episodes completed.
|
||||
* **total_time_steps** - The total number of training time steps completed.
|
||||
* **evaluation**
|
||||
* **total_episodes** - The total number of evaluation episodes completed.
|
||||
* **total_time_steps** - The total number of evaluation time steps completed.
|
||||
* **env**
|
||||
* **training_config**
|
||||
* **All training config items**
|
||||
* **lay_down_config**
|
||||
* **All lay down config items**
|
||||
|
||||
|
||||
**Results**
|
||||
|
||||
PrimAITE automatically creates two sets of results from each learning and evaluation session:
|
||||
|
||||
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
|
||||
* All transactions - a csv file listing the following values for every step of every episode:
|
||||
|
||||
* Timestamp
|
||||
* Episode number
|
||||
* Step number
|
||||
* Reward value
|
||||
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
* Initial observation space (what the blue agent observed when it decided its action)
|
||||
|
||||
**Diagrams**
|
||||
|
||||
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
|
||||
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory.
|
||||
|
||||
**Saved agents**
|
||||
|
||||
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
|
||||
|
||||
**Example Session Directory Structure**
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
~/
|
||||
└── primaite/
|
||||
└── sessions/
|
||||
└── 2023-07-18/
|
||||
└── 2023-07-18_11-06-04/
|
||||
├── evaluation/
|
||||
│ ├── all_transactions_2023-07-18_11-06-04.csv
|
||||
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
|
||||
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
|
||||
├── learning/
|
||||
│ ├── all_transactions_2023-07-18_11-06-04.csv
|
||||
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
|
||||
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
|
||||
│ ├── checkpoints/
|
||||
│ │ └── sb3ppo_10.zip
|
||||
│ ├── SB3_PPO.zip
|
||||
│ └── tensorboard_logs/
|
||||
│ ├── PPO_1/
|
||||
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
|
||||
│ ├── PPO_2/
|
||||
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
|
||||
│ ├── PPO_3/
|
||||
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
|
||||
│ ├── PPO_4/
|
||||
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
|
||||
│ ├── PPO_5/
|
||||
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
|
||||
│ ├── PPO_6/
|
||||
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
|
||||
│ ├── PPO_7/
|
||||
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
|
||||
│ ├── PPO_8/
|
||||
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
|
||||
│ ├── PPO_9/
|
||||
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
|
||||
│ └── PPO_10/
|
||||
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
|
||||
├── network_2023-07-18_11-06-04.png
|
||||
└── session_metadata.json
|
||||
|
||||
Loading a session
|
||||
-----------------
|
||||
|
||||
@@ -78,52 +177,3 @@ A previous session can be loaded by providing the **directory** of the previous
|
||||
run(session_path=<previous session directory>)
|
||||
|
||||
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
|
||||
|
||||
Outputs
|
||||
-------
|
||||
|
||||
PrimAITE produces four types of outputs:
|
||||
|
||||
* Session Metadata
|
||||
* Results
|
||||
* Diagrams
|
||||
* Saved agents
|
||||
|
||||
|
||||
**Session Metadata**
|
||||
|
||||
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
|
||||
|
||||
* **uuid** - The UUID assigned to the session upon instantiation.
|
||||
* **start_datetime** - The date & time the session started in iso format.
|
||||
* **end_datetime** - The date & time the session ended in iso format.
|
||||
* **total_episodes** - The total number of training episodes completed.
|
||||
* **total_time_steps** - The total number of training time steps completed.
|
||||
* **env**
|
||||
* **training_config**
|
||||
* **All training config items**
|
||||
* **lay_down_config**
|
||||
* **All lay down config items**
|
||||
|
||||
|
||||
**Results**
|
||||
|
||||
PrimAITE automatically creates two sets of results from each session:
|
||||
|
||||
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
|
||||
* All transactions - a csv file listing the following values for every step of every episode:
|
||||
|
||||
* Timestamp
|
||||
* Episode number
|
||||
* Step number
|
||||
* Reward value
|
||||
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
* Initial observation space (what the blue agent observed when it decided its action)
|
||||
|
||||
**Diagrams**
|
||||
|
||||
For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
|
||||
|
||||
**Saved agents**
|
||||
|
||||
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
|
||||
|
||||
@@ -6,7 +6,7 @@ from bisect import bisect
|
||||
from logging import Formatter, Logger, LogRecord, StreamHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final
|
||||
from typing import Any, Dict, Final
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
@@ -16,7 +16,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
"""An instance of `PlatformDirs` set with appname='primaite'."""
|
||||
|
||||
|
||||
def _get_primaite_config():
|
||||
def _get_primaite_config() -> Dict:
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
|
||||
def __init__(self, formats: Dict[int, str], **kwargs):
|
||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
|
||||
@@ -1,16 +1,38 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
from typing import Dict
|
||||
import logging
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise an empty AccessControlList."""
|
||||
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
|
||||
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
# Implicit rule in ACL list
|
||||
if self.acl_implicit_permission == RulePermissionType.DENY:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
|
||||
else:
|
||||
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
|
||||
|
||||
# Maximum number of ACL Rules in ACL
|
||||
self.max_acl_rules: int = max_acl_rules
|
||||
# A list of ACL Rules
|
||||
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Union[ACLRule, None]]:
|
||||
"""Public access method for private _acl."""
|
||||
return self._acl + [self.acl_implicit_rule]
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
@@ -47,21 +69,30 @@ class AccessControlList:
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule_value.get_permission() == "DENY":
|
||||
return True
|
||||
elif rule_value.get_permission() == "ALLOW":
|
||||
return False
|
||||
for rule in self.acl:
|
||||
if isinstance(rule, ACLRule):
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == RulePermissionType.DENY:
|
||||
return True
|
||||
elif rule.get_permission() == RulePermissionType.ALLOW:
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def add_rule(
|
||||
self,
|
||||
_permission: RulePermissionType,
|
||||
_source_ip: str,
|
||||
_dest_ip: str,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_position: str,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
@@ -71,12 +102,36 @@ class AccessControlList:
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
|
||||
"""
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
try:
|
||||
position_index = int(_position)
|
||||
except TypeError:
|
||||
_LOGGER.info(f"Position {_position} could not be converted to integer.")
|
||||
return
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# Checks position is in correct range
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
try:
|
||||
_LOGGER.info(f"Position {position_index} is valid.")
|
||||
# Check to see Agent will not overwrite current ACL in ACL list
|
||||
if self._acl[position_index] is None:
|
||||
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
|
||||
# Adds rule
|
||||
self._acl[position_index] = new_rule
|
||||
else:
|
||||
# Cannot overwrite it
|
||||
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
|
||||
return
|
||||
except Exception:
|
||||
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
|
||||
else:
|
||||
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
|
||||
|
||||
def remove_rule(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
@@ -87,19 +142,21 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
try:
|
||||
self.acl.pop(hash_value)
|
||||
except Exception:
|
||||
return
|
||||
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
delete_rule_hash = hash(rule_to_delete)
|
||||
|
||||
def remove_all_rules(self):
|
||||
for index in range(0, len(self._acl)):
|
||||
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
|
||||
self._acl[index] = None
|
||||
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
for i in range(len(self._acl)):
|
||||
self._acl[i] = None
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def get_dictionary_hash(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> int:
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
@@ -117,7 +174,9 @@ class AccessControlList:
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
def get_relevant_rules(
|
||||
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all ACL rules that relate to the given arguments.
|
||||
|
||||
:param _source_ip_address: the source IP address to check
|
||||
@@ -125,18 +184,15 @@ class AccessControlList:
|
||||
:param _protocol: the protocol to check
|
||||
:param _port: the port to check
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[str, ACLRule]
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules = {}
|
||||
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (
|
||||
rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY"
|
||||
) and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY"
|
||||
for rule in self.acl:
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
|
||||
):
|
||||
# There's a matching rule.
|
||||
relevant_rules[rule_key] = rule_value
|
||||
relevant_rules[self._acl.index(rule)] = rule
|
||||
|
||||
return relevant_rules
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def __init__(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
@@ -15,13 +18,13 @@ class ACLRule:
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission = _permission
|
||||
self.source_ip = _source_ip
|
||||
self.dest_ip = _dest_ip
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.permission: RulePermissionType = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
Override the hash function.
|
||||
|
||||
@@ -38,7 +41,7 @@ class ACLRule:
|
||||
)
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
def get_permission(self) -> str:
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
|
||||
@@ -47,7 +50,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
def get_source_ip(self) -> str:
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
|
||||
@@ -56,7 +59,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
def get_dest_ip(self) -> str:
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
@@ -65,7 +68,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
|
||||
@@ -74,7 +77,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets the port attribute.
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ from __future__ import annotations
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
@@ -16,7 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
@@ -51,7 +52,7 @@ class AgentSessionABC(ABC):
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
@@ -131,11 +132,11 @@ class AgentSessionABC(ABC):
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
"""The Agent Session UUID."""
|
||||
return self._uuid
|
||||
|
||||
def _write_session_metadata_file(self):
|
||||
def _write_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Write the ``session_metadata.json`` file.
|
||||
|
||||
@@ -171,7 +172,7 @@ class AgentSessionABC(ABC):
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished writing session metadata file")
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
@@ -200,7 +201,7 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
@@ -210,14 +211,14 @@ class AgentSessionABC(ABC):
|
||||
self._can_evaluate = False
|
||||
|
||||
@abstractmethod
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -234,8 +235,8 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -248,10 +249,10 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def load(self, path: Union[str, Path]):
|
||||
def load(self, path: Union[str, Path]) -> None:
|
||||
"""Load an agent from file."""
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
@@ -275,21 +276,21 @@ class AgentSessionABC(ABC):
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._env.episode_av_reward_writer.close() # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True):
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -24,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
@@ -37,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -48,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -66,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -103,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path=None):
|
||||
def load(cls, path: Union[str, Path] = None) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -33,7 +33,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[Any, Any]:
|
||||
) -> Dict[str, IER]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
@@ -61,7 +61,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
|
||||
def get_matching_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
@@ -84,7 +86,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
@@ -112,7 +114,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_allow_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
@@ -142,7 +144,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Filter ACL rules to only those which are relevant to the specified nodes.
|
||||
|
||||
:param source_node_id: Source node
|
||||
@@ -174,6 +176,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
# TODO: This should throw an error because protocol is a string
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
@@ -187,7 +190,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List ALLOW rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
@@ -234,7 +237,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List DENY rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
|
||||
@@ -102,6 +102,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
|
||||
action = transform_action_node_enum(action, action_dict)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
|
||||
@@ -4,8 +4,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -19,10 +20,11 @@ from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def _env_creator(env_config):
|
||||
# TODO: verify type of env_config
|
||||
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
@@ -31,11 +33,12 @@ def _env_creator(env_config):
|
||||
)
|
||||
|
||||
|
||||
def _custom_log_creator(session_path: Path):
|
||||
# TODO: verify type hint return type
|
||||
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def logger_creator(config):
|
||||
def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
return logger_creator
|
||||
@@ -49,7 +52,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
@@ -74,6 +77,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
@@ -95,7 +99,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
@@ -123,7 +127,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._agent_config = self._agent_config_class()
|
||||
@@ -149,7 +153,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
save_checkpoint = False
|
||||
@@ -160,8 +164,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -181,8 +185,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: None,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -190,7 +194,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -198,7 +202,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, overwrite_existing: bool = True):
|
||||
def save(self, overwrite_existing: bool = True) -> None:
|
||||
"""Save the agent."""
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
@@ -218,6 +222,6 @@ class RLlibAgent(AgentSessionABC):
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -14,7 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
@@ -25,7 +26,7 @@ class SB3Agent(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
@@ -43,6 +44,7 @@ class SB3Agent(AgentSessionABC):
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_class: Union[PPO, A2C]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
@@ -66,7 +68,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
"""Set up the SB3 Agent."""
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
@@ -113,7 +115,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
super()._setup()
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
save_checkpoint = False
|
||||
@@ -124,13 +126,13 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -153,8 +155,8 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -183,10 +185,10 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
@@ -10,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC):
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return self._env.action_space.sample()
|
||||
|
||||
|
||||
@@ -21,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC):
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
@@ -32,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
@@ -47,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = [1, "NONE", "ON", 0]
|
||||
nothing_action = transform_action_node_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
|
||||
@@ -35,11 +35,11 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
||||
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
|
||||
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def build_dirs():
|
||||
def build_dirs() -> None:
|
||||
"""Build the PrimAITE app directories."""
|
||||
from primaite.setup import setup_app_dirs
|
||||
|
||||
@@ -27,7 +27,7 @@ def build_dirs():
|
||||
|
||||
|
||||
@app.command()
|
||||
def reset_notebooks(overwrite: bool = True):
|
||||
def reset_notebooks(overwrite: bool = True) -> None:
|
||||
"""
|
||||
Force a reset of the demo notebooks in the users notebooks directory.
|
||||
|
||||
@@ -39,7 +39,7 @@ def reset_notebooks(overwrite: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
|
||||
@@ -60,7 +60,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
|
||||
|
||||
|
||||
@app.command()
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the PrimAITE Log Level.
|
||||
|
||||
@@ -88,7 +88,7 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks():
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
@@ -96,7 +96,7 @@ def notebooks():
|
||||
|
||||
|
||||
@app.command()
|
||||
def version():
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
import primaite
|
||||
|
||||
@@ -104,7 +104,7 @@ def version():
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up():
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
@@ -112,7 +112,7 @@ def clean_up():
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True):
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None):
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
@@ -185,7 +185,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import Type, Union
|
||||
from typing import Union
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
|
||||
@@ -148,6 +148,7 @@ class ActionType(Enum):
|
||||
ANY = 2
|
||||
|
||||
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
@@ -197,3 +198,11 @@ class SB3OutputVerboseLevel(IntEnum):
|
||||
NONE = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Any firewall rule type."""
|
||||
|
||||
NONE = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name):
|
||||
def __init__(self, _name: str) -> None:
|
||||
"""
|
||||
Initialise a protocol.
|
||||
|
||||
:param _name: The name of the protocol
|
||||
:type _name: str
|
||||
"""
|
||||
self.name = _name
|
||||
self.load = 0 # bps
|
||||
self.name: str = _name
|
||||
self.load: int = 0 # bps
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Gets the protocol name.
|
||||
|
||||
@@ -24,7 +24,7 @@ class Protocol(object):
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets the protocol load.
|
||||
|
||||
@@ -33,7 +33,7 @@ class Protocol(object):
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load):
|
||||
def add_load(self, _load: int) -> None:
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
|
||||
@@ -42,6 +42,6 @@ class Protocol(object):
|
||||
"""
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self):
|
||||
def clear_load(self) -> None:
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
|
||||
@@ -7,7 +7,7 @@ from primaite.common.enums import SoftwareState
|
||||
class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState):
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Initialise a service.
|
||||
|
||||
@@ -15,12 +15,12 @@ class Service(object):
|
||||
:param port: The service port.
|
||||
:param software_state: The service SoftwareState.
|
||||
"""
|
||||
self.name = name
|
||||
self.port = port
|
||||
self.software_state = software_state
|
||||
self.patching_count = 0
|
||||
self.name: str = name
|
||||
self.port: str = port
|
||||
self.software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
|
||||
def reduce_patching_count(self):
|
||||
def reduce_patching_count(self) -> None:
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
|
||||
@@ -163,3 +163,4 @@
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
|
||||
@@ -243,6 +243,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '26'
|
||||
permission: ALLOW
|
||||
@@ -250,6 +251,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '27'
|
||||
permission: ALLOW
|
||||
@@ -257,6 +259,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '28'
|
||||
permission: ALLOW
|
||||
@@ -264,6 +267,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '29'
|
||||
permission: ALLOW
|
||||
@@ -271,6 +275,7 @@
|
||||
destination: 192.168.10.13
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '30'
|
||||
permission: DENY
|
||||
@@ -278,6 +283,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '31'
|
||||
permission: DENY
|
||||
@@ -285,6 +291,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '32'
|
||||
permission: DENY
|
||||
@@ -292,6 +299,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: DENY
|
||||
@@ -299,6 +307,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 8
|
||||
- item_type: RED_POL
|
||||
id: '34'
|
||||
start_step: 20
|
||||
|
||||
@@ -111,6 +111,7 @@
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '12'
|
||||
permission: ALLOW
|
||||
@@ -118,6 +119,7 @@
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '13'
|
||||
permission: ALLOW
|
||||
@@ -125,6 +127,7 @@
|
||||
destination: 192.168.1.3
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 20
|
||||
|
||||
@@ -345,6 +345,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
@@ -352,6 +353,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
@@ -359,6 +361,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
@@ -366,6 +369,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
@@ -373,6 +377,7 @@
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
@@ -380,6 +385,7 @@
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
@@ -387,6 +393,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
@@ -394,6 +401,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
@@ -401,6 +409,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 8
|
||||
- item_type: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
@@ -408,6 +417,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 9
|
||||
- item_type: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
@@ -415,6 +425,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 10
|
||||
- item_type: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
@@ -422,6 +433,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 11
|
||||
- item_type: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
@@ -429,6 +441,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 12
|
||||
- item_type: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
@@ -436,6 +449,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 13
|
||||
- item_type: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
@@ -443,6 +457,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 14
|
||||
- item_type: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
@@ -450,6 +465,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 15
|
||||
- item_type: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
@@ -457,6 +473,7 @@
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 16
|
||||
- item_type: RED_POL
|
||||
id: '50'
|
||||
start_step: 50
|
||||
|
||||
@@ -35,7 +35,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -51,15 +51,15 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
@@ -90,6 +90,11 @@ session_type: TRAIN_EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union
|
||||
|
||||
@@ -6,7 +7,7 @@ import yaml
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
@@ -14,11 +15,12 @@ from primaite.common.enums import (
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
RulePermissionType,
|
||||
SB3OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
@@ -85,7 +87,7 @@ class TrainingConfig:
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: str = False
|
||||
load_agent: bool = False
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
@@ -98,6 +100,12 @@ class TrainingConfig:
|
||||
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
|
||||
"Stable Baselines3 learn/eval output verbosity level"
|
||||
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: float = 0
|
||||
@@ -191,7 +199,7 @@ class TrainingConfig:
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
@@ -206,14 +214,17 @@ class TrainingConfig:
|
||||
"session_type": SessionType,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
|
||||
@@ -230,6 +241,7 @@ class TrainingConfig:
|
||||
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
data["implicit_acl_rule"] = self.implicit_acl_rule.name
|
||||
|
||||
return data
|
||||
|
||||
@@ -332,7 +344,7 @@ def convert_legacy_training_config_dict(
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
@@ -18,14 +20,14 @@ if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise observation component.
|
||||
|
||||
@@ -40,7 +42,7 @@ class AbstractObservationComponent(ABC):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@@ -75,7 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
_MAX_VAL: int = 1_000_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeLinkTable observation space component.
|
||||
|
||||
@@ -102,7 +104,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -149,7 +151,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
nodes = self.env.nodes.values()
|
||||
links = self.env.links.values()
|
||||
@@ -212,7 +214,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeStatuses observation component.
|
||||
|
||||
@@ -238,7 +240,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -269,11 +271,12 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
services = self.env.services_list
|
||||
|
||||
structure = []
|
||||
|
||||
for _, node in self.env.nodes.items():
|
||||
node_id = node.node_id
|
||||
structure.append(f"node_{node_id}_hardware_state_NONE")
|
||||
@@ -318,7 +321,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a LinkTrafficLevels observation component.
|
||||
|
||||
@@ -360,7 +363,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -386,7 +389,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for _, link in self.env.links.items():
|
||||
@@ -402,6 +405,182 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
return structure
|
||||
|
||||
|
||||
class AccessControlList(AbstractObservationComponent):
|
||||
"""Flat list of all the Access Control Rules in the Access Control List.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
acl_rule1 permission,
|
||||
acl_rule1 source_ip,
|
||||
acl_rule1 dest_ip,
|
||||
acl_rule1 protocol,
|
||||
acl_rule1 port,
|
||||
acl_rule1 position,
|
||||
acl_rule2 permission,
|
||||
acl_rule2 source_ip,
|
||||
acl_rule2 dest_ip,
|
||||
acl_rule2 protocol,
|
||||
acl_rule2 port,
|
||||
acl_rule2 position,
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
Terms (for ACL Observation Space):
|
||||
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
|
||||
# Number of ACL rules incremented by 1 for positions starting at index 0.
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
permission = acl_rule.permission
|
||||
source_ip = acl_rule.source_ip
|
||||
dest_ip = acl_rule.dest_ip
|
||||
protocol = acl_rule.protocol
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == RulePermissionType.DENY:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to source IP address
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == source_ip:
|
||||
source_ip_int = int(node.node_id) + 1
|
||||
break
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to dest IP address
|
||||
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == dest_ip:
|
||||
dest_ip_int = int(node.node_id) + 1
|
||||
if protocol == "ANY":
|
||||
protocol_int = 1
|
||||
else:
|
||||
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
|
||||
try:
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = None
|
||||
if port == "ANY":
|
||||
port_int = 1
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port) + 2
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
port_int = None
|
||||
# Add to current obs
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
# The Nothing or NA representation of 'NONE' ACL rules
|
||||
obs.extend([0, 0, 0, 0, 0, 0])
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for acl_rule in self.env.acl.acl:
|
||||
acl_rule_id = self.env.acl.acl.index(acl_rule)
|
||||
|
||||
for permission in RulePermissionType:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
|
||||
for service in self.env.services_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
|
||||
for port in self.env.ports_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
|
||||
|
||||
return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""
|
||||
Component-based observation space handler.
|
||||
@@ -414,9 +593,10 @@ class ObservationsHandler:
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
"ACCESS_CONTROL_LIST": AccessControlList,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the observation handler."""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
@@ -429,9 +609,7 @@ class ObservationsHandler:
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self):
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
@@ -444,7 +622,7 @@ class ObservationsHandler:
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
def register(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Add a component for this handler to track.
|
||||
|
||||
@@ -454,7 +632,7 @@ class ObservationsHandler:
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent):
|
||||
def deregister(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Remove a component from this handler.
|
||||
|
||||
@@ -465,7 +643,7 @@ class ObservationsHandler:
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self):
|
||||
def update_space(self) -> None:
|
||||
"""Rebuild the handler's composite observation space from its components."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
@@ -482,23 +660,23 @@ class ObservationsHandler:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
|
||||
def space(self):
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_space
|
||||
else:
|
||||
return self._space
|
||||
|
||||
@property
|
||||
def current_observation(self):
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_observation
|
||||
else:
|
||||
return self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
|
||||
"""
|
||||
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
@@ -527,9 +705,6 @@ class ObservationsHandler:
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
if obs_space_config.get("flatten"):
|
||||
handler.flatten = True
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
@@ -544,7 +719,7 @@ class ObservationsHandler:
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
|
||||
def describe_structure(self) -> List[str]:
|
||||
"""
|
||||
Create a list of names for the features of the obs space.
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
import copy
|
||||
import logging
|
||||
import uuid as uuid
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Dict, Final, Tuple, Union
|
||||
from typing import Any, Dict, Final, List, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -20,6 +21,7 @@ from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
@@ -48,7 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
@@ -66,7 +68,7 @@ class Primaite(Env):
|
||||
lay_down_config_path: Union[str, Path],
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The Primaite constructor.
|
||||
|
||||
@@ -77,13 +79,14 @@ class Primaite(Env):
|
||||
"""
|
||||
self.session_path: Final[Path] = session_path
|
||||
self.timestamp_str: Final[str] = timestamp_str
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
self._training_config_path: Union[str, Path] = training_config_path
|
||||
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
|
||||
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
self.episode_steps: int
|
||||
if self.training_config.session_type == SessionType.TRAIN:
|
||||
self.episode_steps = self.training_config.num_train_steps
|
||||
elif self.training_config.session_type == SessionType.EVAL:
|
||||
@@ -94,7 +97,7 @@ class Primaite(Env):
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.training_config.agent_identifier
|
||||
self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier
|
||||
|
||||
# Create a dictionary to hold all the nodes
|
||||
self.nodes: Dict[str, NodeUnion] = {}
|
||||
@@ -113,37 +116,42 @@ class Primaite(Env):
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
self.node_pol: Dict[str, NodeStateInstructionGreen] = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
|
||||
self.red_iers = {}
|
||||
self.red_iers: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent node PoLs (this will come from an external source)
|
||||
self.red_node_pol = {}
|
||||
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
|
||||
|
||||
# Create the Access Control List
|
||||
self.acl = AccessControlList()
|
||||
self.acl: AccessControlList = AccessControlList(
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
# Sets limit for number of ACL rules in environment
|
||||
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
self.services_list: List[str] = []
|
||||
|
||||
# Create a list of ports
|
||||
self.ports_list = []
|
||||
self.ports_list: List[str] = []
|
||||
|
||||
# Create graph (network)
|
||||
self.network = nx.MultiGraph()
|
||||
self.network: nx.Graph = nx.MultiGraph()
|
||||
|
||||
# Create a graph (network) reference
|
||||
self.network_reference = nx.MultiGraph()
|
||||
self.network_reference: nx.Graph = nx.MultiGraph()
|
||||
|
||||
# Create step count
|
||||
self.step_count = 0
|
||||
self.step_count: int = 0
|
||||
|
||||
self.total_step_count: int = 0
|
||||
"""The total number of time steps completed."""
|
||||
|
||||
# Create step info dictionary
|
||||
self.step_info = {}
|
||||
self.step_info: Dict[Any] = {}
|
||||
|
||||
# Total reward
|
||||
self.total_reward: float = 0
|
||||
@@ -152,22 +160,23 @@ class Primaite(Env):
|
||||
self.average_reward: float = 0
|
||||
|
||||
# Episode count
|
||||
self.episode_count = 0
|
||||
self.episode_count: int = 0
|
||||
|
||||
# Number of nodes - gets a value by examining the nodes dictionary after it's been populated
|
||||
self.num_nodes = 0
|
||||
self.num_nodes: int = 0
|
||||
|
||||
# Number of links - gets a value by examining the links dictionary after it's been populated
|
||||
self.num_links = 0
|
||||
self.num_links: int = 0
|
||||
|
||||
# Number of services - gets a value when config is loaded
|
||||
self.num_services = 0
|
||||
self.num_services: int = 0
|
||||
|
||||
# Number of ports - gets a value when config is loaded
|
||||
self.num_ports = 0
|
||||
self.num_ports: int = 0
|
||||
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
# TODO: confirm type
|
||||
self.action_type: int = 0
|
||||
|
||||
# TODO fix up with TrainingConfig
|
||||
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||
@@ -179,7 +188,7 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
self._obs_space_description = None
|
||||
self._obs_space_description: List[str] = None
|
||||
"The env observation space description for transactions writing"
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -211,9 +220,13 @@ class Primaite(Env):
|
||||
_LOGGER.error("Could not save network diagram", exc_info=True)
|
||||
|
||||
# Initiate observation space
|
||||
self.observation_space: spaces.Space
|
||||
self.env_obs: np.ndarray
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
self.action_dict: Dict[int, List[int]]
|
||||
self.action_space: spaces.Space
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
_LOGGER.debug("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
@@ -241,8 +254,12 @@ class Primaite(Env):
|
||||
else:
|
||||
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
|
||||
self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=True
|
||||
)
|
||||
self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
|
||||
@property
|
||||
def actual_episode_count(self) -> int:
|
||||
@@ -251,7 +268,7 @@ class Primaite(Env):
|
||||
return self.episode_count - 1
|
||||
return self.episode_count
|
||||
|
||||
def set_as_eval(self):
|
||||
def set_as_eval(self) -> None:
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
|
||||
@@ -260,12 +277,12 @@ class Primaite(Env):
|
||||
self.total_step_count = 0
|
||||
self.episode_steps = self.training_config.num_eval_steps
|
||||
|
||||
def _write_av_reward_per_episode(self):
|
||||
def _write_av_reward_per_episode(self) -> None:
|
||||
if self.actual_episode_count > 0:
|
||||
csv_data = self.actual_episode_count, self.average_reward
|
||||
self.episode_av_reward_writer.write(csv_data)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> np.ndarray:
|
||||
"""
|
||||
AI Gym Reset function.
|
||||
|
||||
@@ -299,7 +316,7 @@ class Primaite(Env):
|
||||
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
|
||||
"""
|
||||
AI Gym Step function.
|
||||
|
||||
@@ -418,7 +435,7 @@ class Primaite(Env):
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Override parent close and close writers."""
|
||||
# Close files if last episode/step
|
||||
# if self.can_finish:
|
||||
@@ -427,18 +444,18 @@ class Primaite(Env):
|
||||
self.transaction_writer.close()
|
||||
self.episode_av_reward_writer.close()
|
||||
|
||||
def init_acl(self):
|
||||
def init_acl(self) -> None:
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
def output_link_status(self):
|
||||
def output_link_status(self) -> None:
|
||||
"""Output the link status of all links to the console."""
|
||||
for link_key, link_value in self.links.items():
|
||||
_LOGGER.debug("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.protocol_list:
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
def interpret_action_and_apply(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
@@ -446,19 +463,18 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
|
||||
elif len(self.action_dict[_action]) == 7: # ACL actions in multidiscrete form have len 7
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
def apply_actions_to_nodes(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the nodes.
|
||||
|
||||
@@ -546,7 +562,7 @@ class Primaite(Env):
|
||||
else:
|
||||
return
|
||||
|
||||
def apply_actions_to_acl(self, _action):
|
||||
def apply_actions_to_acl(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the Access Control List [TO DO].
|
||||
|
||||
@@ -562,6 +578,7 @@ class Primaite(Env):
|
||||
action_destination_ip = readable_action[3]
|
||||
action_protocol = readable_action[4]
|
||||
action_port = readable_action[5]
|
||||
acl_rule_position = readable_action[6]
|
||||
|
||||
if action_decision == 0:
|
||||
# It's decided to do nothing
|
||||
@@ -611,6 +628,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_rule_position,
|
||||
)
|
||||
elif action_decision == 2:
|
||||
# Remove the rule
|
||||
@@ -624,7 +642,7 @@ class Primaite(Env):
|
||||
else:
|
||||
return
|
||||
|
||||
def apply_time_based_updates(self):
|
||||
def apply_time_based_updates(self) -> None:
|
||||
"""
|
||||
Updates anything that needs to count down and then change state.
|
||||
|
||||
@@ -680,12 +698,12 @@ class Primaite(Env):
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
def update_environent_obs(self):
|
||||
def update_environent_obs(self) -> None:
|
||||
"""Updates the observation space based on the node and link status."""
|
||||
self.obs_handler.update_obs()
|
||||
self.env_obs = self.obs_handler.current_observation
|
||||
|
||||
def load_lay_down_config(self):
|
||||
def load_lay_down_config(self) -> None:
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
for item in self.lay_down_config:
|
||||
if item["item_type"] == "NODE":
|
||||
@@ -723,7 +741,7 @@ class Primaite(Env):
|
||||
_LOGGER.info("Environment configuration loaded")
|
||||
print("Environment configuration loaded")
|
||||
|
||||
def create_node(self, item):
|
||||
def create_node(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a node from config data.
|
||||
|
||||
@@ -804,7 +822,7 @@ class Primaite(Env):
|
||||
# Add node to network (reference)
|
||||
self.network_reference.add_nodes_from([node_ref])
|
||||
|
||||
def create_link(self, item: Dict):
|
||||
def create_link(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a link from config data.
|
||||
|
||||
@@ -848,7 +866,7 @@ class Primaite(Env):
|
||||
self.services_list,
|
||||
)
|
||||
|
||||
def create_green_ier(self, item):
|
||||
def create_green_ier(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a green IER from config data.
|
||||
|
||||
@@ -889,7 +907,7 @@ class Primaite(Env):
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
def create_red_ier(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a red IER from config data.
|
||||
|
||||
@@ -919,7 +937,7 @@ class Primaite(Env):
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_green_pol(self, item):
|
||||
def create_green_pol(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a green PoL object from config data.
|
||||
|
||||
@@ -953,7 +971,7 @@ class Primaite(Env):
|
||||
pol_state,
|
||||
)
|
||||
|
||||
def create_red_pol(self, item):
|
||||
def create_red_pol(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates a red PoL object from config data.
|
||||
|
||||
@@ -994,7 +1012,7 @@ class Primaite(Env):
|
||||
pol_source_node_service_state,
|
||||
)
|
||||
|
||||
def create_acl_rule(self, item):
|
||||
def create_acl_rule(self, item: Dict) -> None:
|
||||
"""
|
||||
Creates an ACL rule from config data.
|
||||
|
||||
@@ -1006,6 +1024,7 @@ class Primaite(Env):
|
||||
acl_rule_destination = item["destination"]
|
||||
acl_rule_protocol = item["protocol"]
|
||||
acl_rule_port = item["port"]
|
||||
acl_rule_position = item["position"]
|
||||
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -1013,9 +1032,11 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_rule_position,
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
# TODO: confirm typehint using runtime
|
||||
def create_services_list(self, services: Dict) -> None:
|
||||
"""
|
||||
Creates a list of services (enum) from config data.
|
||||
|
||||
@@ -1031,7 +1052,7 @@ class Primaite(Env):
|
||||
# Set the number of services
|
||||
self.num_services = len(self.services_list)
|
||||
|
||||
def create_ports_list(self, ports):
|
||||
def create_ports_list(self, ports: Dict) -> None:
|
||||
"""
|
||||
Creates a list of ports from config data.
|
||||
|
||||
@@ -1047,7 +1068,8 @@ class Primaite(Env):
|
||||
# Set the number of ports
|
||||
self.num_ports = len(self.ports_list)
|
||||
|
||||
def get_observation_info(self, observation_info):
|
||||
# TODO: this is not used anymore, write a ticket to delete it
|
||||
def get_observation_info(self, observation_info: Dict) -> None:
|
||||
"""
|
||||
Extracts observation_info.
|
||||
|
||||
@@ -1056,7 +1078,8 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
def get_action_info(self, action_info: Dict) -> None:
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
@@ -1065,7 +1088,7 @@ class Primaite(Env):
|
||||
"""
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def save_obs_config(self, obs_config: dict):
|
||||
def save_obs_config(self, obs_config: dict) -> None:
|
||||
"""
|
||||
Cache the config for the observation space.
|
||||
|
||||
@@ -1078,7 +1101,7 @@ class Primaite(Env):
|
||||
"""
|
||||
self.obs_config = obs_config
|
||||
|
||||
def reset_environment(self):
|
||||
def reset_environment(self) -> None:
|
||||
"""
|
||||
Resets environment.
|
||||
|
||||
@@ -1103,7 +1126,7 @@ class Primaite(Env):
|
||||
for ier_key, ier_value in self.red_iers.items():
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
def reset_node(self, item):
|
||||
def reset_node(self, item: Dict) -> None:
|
||||
"""
|
||||
Resets the statuses of a node.
|
||||
|
||||
@@ -1151,7 +1174,7 @@ class Primaite(Env):
|
||||
# Bad formatting
|
||||
pass
|
||||
|
||||
def create_node_action_dict(self):
|
||||
def create_node_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""
|
||||
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
|
||||
@@ -1167,6 +1190,11 @@ class Primaite(Env):
|
||||
...
|
||||
}
|
||||
"""
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [1, 0, 0, 0]}
|
||||
action_key = 1
|
||||
@@ -1186,10 +1214,18 @@ class Primaite(Env):
|
||||
|
||||
return actions
|
||||
|
||||
def create_acl_action_dict(self):
|
||||
def create_acl_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
@@ -1201,22 +1237,25 @@ class Primaite(Env):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
for position in range(self.max_number_acl_rules - 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
position,
|
||||
]
|
||||
# Check to see if it is an action we want to include as possible
|
||||
# i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
def create_node_and_acl_action_dict(self):
|
||||
def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""
|
||||
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
|
||||
@@ -1233,7 +1272,7 @@ class Primaite(Env):
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
return combined_action_dict
|
||||
|
||||
def _create_random_red_agent(self):
|
||||
def _create_random_red_agent(self) -> None:
|
||||
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||
# Reset the current red iers and red node pol
|
||||
self.red_iers = {}
|
||||
|
||||
@@ -1,25 +1,31 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
from logging import Logger
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
initial_nodes: Dict[str, NodeUnion],
|
||||
final_nodes: Dict[str, NodeUnion],
|
||||
reference_nodes: Dict[str, NodeUnion],
|
||||
green_iers: Dict[str, "IER"],
|
||||
green_iers_reference: Dict[str, "IER"],
|
||||
red_iers: Dict[str, "IER"],
|
||||
step_count: int,
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
@@ -93,7 +99,9 @@ def calculate_reward_function(
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_operating_state(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
@@ -142,7 +150,12 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_os_state(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
@@ -193,7 +206,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_service_state(
|
||||
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
@@ -265,7 +280,12 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
|
||||
def score_node_file_system(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.common.protocol import Protocol
|
||||
class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
|
||||
"""
|
||||
Initialise a Link within the simulated network.
|
||||
|
||||
@@ -18,17 +18,17 @@ class Link(object):
|
||||
:param _dest_node_name: The name of the destination node
|
||||
:param _protocols: The protocols to add to the link
|
||||
"""
|
||||
self.id = _id
|
||||
self.bandwidth = _bandwidth
|
||||
self.source_node_name = _source_node_name
|
||||
self.dest_node_name = _dest_node_name
|
||||
self.id: str = _id
|
||||
self.bandwidth: int = _bandwidth
|
||||
self.source_node_name: str = _source_node_name
|
||||
self.dest_node_name: str = _dest_node_name
|
||||
self.protocol_list: List[Protocol] = []
|
||||
|
||||
# Add the default protocols
|
||||
for protocol_name in _services:
|
||||
self.add_protocol(protocol_name)
|
||||
|
||||
def add_protocol(self, _protocol):
|
||||
def add_protocol(self, _protocol: str) -> None:
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
@@ -37,7 +37,7 @@ class Link(object):
|
||||
"""
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets link ID.
|
||||
|
||||
@@ -46,7 +46,7 @@ class Link(object):
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self):
|
||||
def get_source_node_name(self) -> str:
|
||||
"""
|
||||
Gets source node name.
|
||||
|
||||
@@ -55,7 +55,7 @@ class Link(object):
|
||||
"""
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self):
|
||||
def get_dest_node_name(self) -> str:
|
||||
"""
|
||||
Gets destination node name.
|
||||
|
||||
@@ -64,7 +64,7 @@ class Link(object):
|
||||
"""
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self):
|
||||
def get_bandwidth(self) -> int:
|
||||
"""
|
||||
Gets bandwidth of link.
|
||||
|
||||
@@ -73,7 +73,7 @@ class Link(object):
|
||||
"""
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self):
|
||||
def get_protocol_list(self) -> List[Protocol]:
|
||||
"""
|
||||
Gets list of protocols on this link.
|
||||
|
||||
@@ -82,7 +82,7 @@ class Link(object):
|
||||
"""
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self):
|
||||
def get_current_load(self) -> int:
|
||||
"""
|
||||
Gets current total load on this link.
|
||||
|
||||
@@ -94,7 +94,7 @@ class Link(object):
|
||||
total_load += protocol.get_load()
|
||||
return total_load
|
||||
|
||||
def add_protocol_load(self, _protocol, _load):
|
||||
def add_protocol_load(self, _protocol: str, _load: int) -> None:
|
||||
"""
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
@@ -108,7 +108,7 @@ class Link(object):
|
||||
else:
|
||||
pass
|
||||
|
||||
def clear_traffic(self):
|
||||
def clear_traffic(self) -> None:
|
||||
"""Clears all traffic on this link."""
|
||||
for protocol in self.protocol_list:
|
||||
protocol.clear_load()
|
||||
|
||||
@@ -14,7 +14,7 @@ def run(
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class ActiveNode(Node):
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an active node.
|
||||
|
||||
@@ -60,7 +60,7 @@ class ActiveNode(Node):
|
||||
return self._software_state
|
||||
|
||||
@software_state.setter
|
||||
def software_state(self, software_state: SoftwareState):
|
||||
def software_state(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
@@ -79,7 +79,7 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
@@ -99,14 +99,14 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def update_os_patching_status(self):
|
||||
def update_os_patching_status(self) -> None:
|
||||
"""Updates operating system status based on patching cycle."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self._software_state = SoftwareState.GOOD
|
||||
|
||||
def set_file_system_state(self, file_system_state: FileSystemState):
|
||||
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
@@ -133,7 +133,7 @@ class ActiveNode(Node):
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
@@ -166,12 +166,12 @@ class ActiveNode(Node):
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def start_file_system_scan(self):
|
||||
def start_file_system_scan(self) -> None:
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def update_file_system_state(self):
|
||||
def update_file_system_state(self) -> None:
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
# Deprecate both the action count (for restoring or reparing) and the scanning count
|
||||
self.file_system_action_count -= 1
|
||||
@@ -193,14 +193,14 @@ class ActiveNode(Node):
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
def update_resetting_status(self):
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the reset count & makes software and file state to GOOD."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting software and file state to GOOD."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
|
||||
@@ -17,7 +17,7 @@ class Node:
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a node.
|
||||
|
||||
@@ -38,40 +38,40 @@ class Node:
|
||||
self.booting_count: int = 0
|
||||
self.shutting_down_count: int = 0
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Returns the name of the node."""
|
||||
return self.name
|
||||
|
||||
def turn_on(self):
|
||||
def turn_on(self) -> None:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self):
|
||||
def turn_off(self) -> None:
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
self.shutting_down_count = self.config_values.node_shutdown_duration
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Sets the node state to Resetting and starts the reset count."""
|
||||
self.hardware_state = HardwareState.RESETTING
|
||||
self.resetting_count = self.config_values.node_reset_duration
|
||||
|
||||
def update_resetting_status(self):
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the resetting count."""
|
||||
self.resetting_count -= 1
|
||||
if self.resetting_count <= 0:
|
||||
self.resetting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self):
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self):
|
||||
def update_shutdown_status(self) -> None:
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
@@ -7,14 +11,14 @@ class NodeStateInstructionGreen(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_node_id,
|
||||
_node_pol_type,
|
||||
_service_name,
|
||||
_state,
|
||||
):
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_node_id: str,
|
||||
_node_pol_type: "NodePOLType",
|
||||
_service_name: str,
|
||||
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction.
|
||||
|
||||
@@ -30,11 +34,12 @@ class NodeStateInstructionGreen(object):
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
self.node_pol_type: "NodePOLType" = _node_pol_type
|
||||
self.service_name: str = _service_name # Not used when not a service instruction
|
||||
# TODO: confirm type of state
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
@@ -43,7 +48,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
@@ -52,7 +57,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
def get_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
@@ -61,7 +66,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
def get_node_pol_type(self) -> "NodePOLType":
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
@@ -70,7 +75,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
@@ -79,7 +84,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
|
||||
|
||||
|
||||
@dataclass()
|
||||
class NodeStateInstructionRed(object):
|
||||
@@ -11,18 +15,18 @@ class NodeStateInstructionRed(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_target_node_id: str,
|
||||
_pol_initiator: "NodePOLInitiator",
|
||||
_pol_type: NodePOLType,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
):
|
||||
pol_protocol: str,
|
||||
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
_pol_source_node_id: str,
|
||||
_pol_source_node_service: str,
|
||||
_pol_source_node_service_state: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction for the red agent.
|
||||
|
||||
@@ -38,19 +42,19 @@ class NodeStateInstructionRed(object):
|
||||
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.target_node_id: str = _target_node_id
|
||||
self.initiator: "NodePOLInitiator" = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
self.service_name: str = pol_protocol # Not used when not a service instruction
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
|
||||
self.source_node_id: str = _pol_source_node_id
|
||||
self.source_node_service: str = _pol_source_node_service
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
@@ -59,7 +63,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
@@ -68,7 +72,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
def get_target_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
@@ -77,7 +81,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
def get_initiator(self) -> "NodePOLInitiator":
|
||||
"""
|
||||
Gets the initiator.
|
||||
|
||||
@@ -95,7 +99,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
@@ -104,7 +108,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
@@ -113,7 +117,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
@@ -122,7 +126,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
def get_source_node_service(self) -> str:
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
@@ -131,7 +135,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
def get_source_node_service_state(self) -> str:
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class PassiveNode(Node):
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a passive node.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class ServiceNode(ActiveNode):
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a Service Node.
|
||||
|
||||
@@ -52,7 +52,7 @@ class ServiceNode(ActiveNode):
|
||||
)
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service):
|
||||
def add_service(self, service: Service) -> None:
|
||||
"""
|
||||
Adds a service to the node.
|
||||
|
||||
@@ -102,7 +102,7 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -131,7 +131,7 @@ class ServiceNode(ActiveNode):
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -158,7 +158,7 @@ class ServiceNode(ActiveNode):
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name):
|
||||
def get_service_state(self, protocol_name: str) -> SoftwareState:
|
||||
"""
|
||||
Gets the state of a service.
|
||||
|
||||
@@ -169,20 +169,20 @@ class ServiceNode(ActiveNode):
|
||||
if service_value:
|
||||
return service_value.software_state
|
||||
|
||||
def update_services_patching_status(self):
|
||||
def update_services_patching_status(self) -> None:
|
||||
"""Updates the patching counter for any service that are patching."""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_value.software_state == SoftwareState.PATCHING:
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self):
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Update resetting counter and set software state if it reached 0."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
def update_booting_status(self) -> None:
|
||||
"""Update booting counter and set software to good if it reached 0."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
|
||||
@@ -5,13 +5,14 @@ import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from logging import Logger
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session():
|
||||
def start_jupyter_session() -> None:
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict, Union
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
@@ -10,11 +10,10 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE = False
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_iers(
|
||||
@@ -24,7 +23,7 @@ def apply_iers(
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
@@ -65,6 +64,8 @@ def apply_iers(
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
|
||||
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
@@ -215,9 +216,9 @@ def apply_iers(
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
node_pol: Dict[str, NodeStateInstructionGreen],
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
|
||||
@@ -11,17 +11,17 @@ class IER(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_load,
|
||||
_protocol,
|
||||
_port,
|
||||
_source_node_id,
|
||||
_dest_node_id,
|
||||
_mission_criticality,
|
||||
_running=False,
|
||||
):
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_load: int,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_source_node_id: str,
|
||||
_dest_node_id: str,
|
||||
_mission_criticality: int,
|
||||
_running: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an Information Exchange Request.
|
||||
|
||||
@@ -36,18 +36,18 @@ class IER(object):
|
||||
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
:param _running: Indicates whether the IER is currently running
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.source_node_id = _source_node_id
|
||||
self.dest_node_id = _dest_node_id
|
||||
self.load = _load
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.mission_criticality = _mission_criticality
|
||||
self.running = _running
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.source_node_id: str = _source_node_id
|
||||
self.dest_node_id: str = _dest_node_id
|
||||
self.load: int = _load
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
self.mission_criticality: int = _mission_criticality
|
||||
self.running: bool = _running
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets IER ID.
|
||||
|
||||
@@ -56,7 +56,7 @@ class IER(object):
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets IER start step.
|
||||
|
||||
@@ -65,7 +65,7 @@ class IER(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets IER end step.
|
||||
|
||||
@@ -74,7 +74,7 @@ class IER(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets IER load.
|
||||
|
||||
@@ -83,7 +83,7 @@ class IER(object):
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets IER protocol.
|
||||
|
||||
@@ -92,7 +92,7 @@ class IER(object):
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets IER port.
|
||||
|
||||
@@ -101,7 +101,7 @@ class IER(object):
|
||||
"""
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
|
||||
@@ -110,7 +110,7 @@ class IER(object):
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
def get_dest_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
|
||||
@@ -119,7 +119,7 @@ class IER(object):
|
||||
"""
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
def get_is_running(self) -> bool:
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
@@ -128,7 +128,7 @@ class IER(object):
|
||||
"""
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
def set_is_running(self, _value: bool) -> None:
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
|
||||
@@ -137,7 +137,7 @@ class IER(object):
|
||||
"""
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self):
|
||||
def get_mission_criticality(self) -> int:
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
@@ -13,7 +14,9 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE = False
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(
|
||||
@@ -23,7 +26,7 @@ def apply_red_agent_iers(
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
@@ -74,6 +77,9 @@ def apply_red_agent_iers(
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
|
||||
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
|
||||
# to change according to duck typing.
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
@@ -213,7 +219,7 @@ def apply_red_agent_node_pol(
|
||||
iers: Dict[str, IER],
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
@@ -267,8 +273,7 @@ def apply_red_agent_node_pol(
|
||||
# Do nothing, service not on this node
|
||||
pass
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - misconfiguration")
|
||||
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
|
||||
|
||||
# Only apply the PoL if the checks have passed (based on the initiator type)
|
||||
if passed_checks:
|
||||
@@ -289,8 +294,7 @@ def apply_red_agent_node_pol(
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -32,7 +32,7 @@ class PrimaiteSession:
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
@@ -72,7 +72,13 @@ class PrimaiteSession:
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
def setup(self):
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
@@ -155,8 +161,8 @@ class PrimaiteSession:
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -167,8 +173,8 @@ class PrimaiteSession:
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -177,6 +183,6 @@ class PrimaiteSession:
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,16 +2,17 @@
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True):
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite import getLogger, USERS_CONFIG_DIR
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing=True):
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from logging import Logger
|
||||
|
||||
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""
|
||||
Handles creation of application directories and user directories.
|
||||
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
@@ -17,7 +21,7 @@ class Transaction(object):
|
||||
:param episode_number: The episode number
|
||||
:param step_number: The step number
|
||||
"""
|
||||
self.timestamp = datetime.now()
|
||||
self.timestamp: datetime = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier: AgentIdentifier = agent_identifier
|
||||
"The agent identifier"
|
||||
@@ -25,17 +29,17 @@ class Transaction(object):
|
||||
"The episode number"
|
||||
self.step_number: int = step_number
|
||||
"The step number"
|
||||
self.obs_space = None
|
||||
self.obs_space: "spaces.Space" = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre = None
|
||||
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post = None
|
||||
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward: float = None
|
||||
self.reward: Optional[float] = None
|
||||
"The reward value"
|
||||
self.action_space = None
|
||||
self.action_space: Optional[int] = None
|
||||
"The action space invoked by the agent"
|
||||
self.obs_space_description = None
|
||||
self.obs_space_description: Optional[List[str]] = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
@@ -68,7 +72,7 @@ class Transaction(object):
|
||||
return header, row
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
@@ -81,7 +85,7 @@ def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
return [str(action_space)]
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -10,7 +10,7 @@ from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Loads a session metadata from the given directory path.
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ from primaite import getLogger
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
@@ -29,7 +32,7 @@ class SessionOutputWriter:
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Session Output Writer.
|
||||
|
||||
@@ -42,15 +45,16 @@ class SessionOutputWriter:
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
:type learning_session: bool, optional
|
||||
"""
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
self.learning_session = learning_session
|
||||
self._env: "Primaite" = env
|
||||
self.transaction_writer: bool = transaction_writer
|
||||
self.learning_session: bool = learning_session
|
||||
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
self._csv_file_path = self._env.session_path / "learning" / fn
|
||||
else:
|
||||
@@ -58,26 +62,26 @@ class SessionOutputWriter:
|
||||
|
||||
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
self._csv_file = None
|
||||
self._csv_writer = None
|
||||
self._csv_file: "TextIOWrapper" = None
|
||||
self._csv_writer: "csv._writer" = None
|
||||
|
||||
self._first_write: bool = True
|
||||
|
||||
def _init_csv_writer(self):
|
||||
def _init_csv_writer(self) -> None:
|
||||
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
|
||||
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close the cvs file."""
|
||||
if self._csv_file:
|
||||
self._csv_file.close()
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
def write(self, data: Union[Tuple, Transaction]) -> None:
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -92,6 +92,7 @@
|
||||
destination: 192.168.1.2
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '7'
|
||||
permission: ALLOW
|
||||
@@ -99,3 +100,4 @@
|
||||
destination: 192.168.1.1
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
|
||||
86
tests/config/obs_tests/laydown_ACL.yaml
Normal file
86
tests/config/obs_tests/laydown_ACL.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- port: '21'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- name: FTP
|
||||
|
||||
########################################
|
||||
# Nodes
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.1
|
||||
software_state: COMPROMISED
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: FTP
|
||||
port: '21'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: FTP
|
||||
port: '21'
|
||||
state: OVERWHELMED
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
|
||||
########################################
|
||||
# Links
|
||||
- item_type: LINK
|
||||
id: '4'
|
||||
name: link1
|
||||
bandwidth: 1000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '5'
|
||||
name: link2
|
||||
bandwidth: 1000
|
||||
source: '3'
|
||||
destination: '2'
|
||||
|
||||
#########################################
|
||||
# IERS
|
||||
- item_type: GREEN_IER
|
||||
id: '5'
|
||||
start_step: 0
|
||||
end_step: 5
|
||||
load: 999
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
|
||||
#########################################
|
||||
# ACL Rules
|
||||
106
tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml
Normal file
106
tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml
Normal file
@@ -0,0 +1,106 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_framework: SB3
|
||||
agent_identifier: PPO
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 1
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 5
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 3
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -39,6 +39,11 @@ observation_space:
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
|
||||
@@ -37,6 +37,11 @@ observation_space:
|
||||
time_delay: 1
|
||||
# Filename of the scenario / laydown
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
|
||||
@@ -40,7 +40,8 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -91,6 +91,8 @@ session_type: EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
implicit_acl_rule: DENY
|
||||
max_number_acl_rules: 10
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -36,7 +36,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: None
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -55,11 +55,11 @@ hard_coded_agent_view: FULL
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# - name: ACCESS_CONTROL_LIST
|
||||
# Number of episodes to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: 67890
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -55,7 +55,6 @@ hard_coded_agent_view: FULL
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
|
||||
@@ -38,6 +38,15 @@ load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 10
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '21'
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: ftp
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: node
|
||||
@@ -16,8 +16,8 @@
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: COMPROMISED
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
@@ -30,15 +30,15 @@
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: COMPROMISED
|
||||
- item_type: RED_IER
|
||||
id: '3'
|
||||
start_step: 2
|
||||
end_step: 15
|
||||
load: 1000
|
||||
protocol: ftp
|
||||
protocol: TCP
|
||||
port: CORRUPT
|
||||
source: '1'
|
||||
destination: '2'
|
||||
|
||||
@@ -47,6 +47,9 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
||||
implicit_acl_rule: DENY
|
||||
max_number_acl_rules: 10
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -3,40 +3,41 @@
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
def test_acl_address_match_1():
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
@@ -44,14 +45,15 @@ def test_acl_address_match_4():
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "ALLOW"
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -59,22 +61,23 @@ def test_check_acl_block_affirmative():
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
|
||||
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "DENY"
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -82,6 +85,7 @@ def test_check_acl_block_negative():
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
|
||||
@@ -90,11 +94,73 @@ def test_check_acl_block_negative():
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
|
||||
def test_delete_rule():
|
||||
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.ALLOW, 10)
|
||||
|
||||
# Create a first rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a second rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "20"
|
||||
acl_rule_destination = "30"
|
||||
acl_rule_protocol = "FTP"
|
||||
acl_rule_port = "21"
|
||||
acl_position_in_list = "2"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a third rule
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.3"
|
||||
acl_rule_destination = "192.168.1.1"
|
||||
acl_rule_protocol = "UDP"
|
||||
acl_rule_port = "60"
|
||||
acl_position_in_list = "4"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
# Remove the second ACL rule added from the list
|
||||
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
|
||||
|
||||
assert len(acl.acl) == 10
|
||||
assert acl.acl[2] is None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Test env creation and behaviour with different observation spaces."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
@@ -237,3 +238,140 @@ class TestLinkTrafficLevels:
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestAccessControlList:
|
||||
"""Test the AccessControlList observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with MultiDiscrete observation space.
|
||||
|
||||
The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have.
|
||||
Each ACL Rule in the observation space has 6 different elements:
|
||||
|
||||
6 * 3 = 18
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
assert env.env_obs.shape == (18,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* one ACL IMPLICIT DENY rule
|
||||
|
||||
Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at
|
||||
Position 2.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_implicit_rule(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0
|
||||
|
||||
On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1
|
||||
|
||||
The final observation space should be this:
|
||||
[2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 43
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 96
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_different_positions(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1
|
||||
|
||||
On Step 4 there is a second ACL rule added at Position 0: 2,4,2,3,3,0
|
||||
|
||||
The final observation space should be this:
|
||||
[2 , 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added before and has a LOWER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 44
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 95
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
@@ -11,30 +11,46 @@ from tests import TEST_CONFIG_ROOT
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session):
|
||||
"""Test running seeded learning produces the same output when ran twice."""
|
||||
"""
|
||||
Test running seeded learning produces the same output when ran twice.
|
||||
|
||||
.. note::
|
||||
|
||||
If this is failing, the hard-coded expected_mean_reward_per_episode
|
||||
from a pre-trained agent will probably need to be updated. If the
|
||||
env changes and those changed how this agent is trained, chances are
|
||||
the mean rewards are going to be different.
|
||||
|
||||
Run the test, but print out the session.learn_av_reward_per_episode()
|
||||
before comparing it. Then copy the printed dict and replace the
|
||||
expected_mean_reward_per_episode with those values. The test should
|
||||
now work. If not, then you've got a bug :).
|
||||
"""
|
||||
expected_mean_reward_per_episode = {
|
||||
1: -90.703125,
|
||||
2: -91.15234375,
|
||||
3: -87.5,
|
||||
4: -92.2265625,
|
||||
5: -94.6875,
|
||||
6: -91.19140625,
|
||||
7: -88.984375,
|
||||
8: -88.3203125,
|
||||
9: -112.79296875,
|
||||
10: -100.01953125,
|
||||
1: -20.7421875,
|
||||
2: -19.82421875,
|
||||
3: -17.01171875,
|
||||
4: -19.08203125,
|
||||
5: -21.93359375,
|
||||
6: -20.21484375,
|
||||
7: -15.546875,
|
||||
8: -12.08984375,
|
||||
9: -17.59765625,
|
||||
10: -14.6875,
|
||||
}
|
||||
|
||||
with temp_primaite_session as session:
|
||||
assert session._training_config.seed == 67890, (
|
||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||
)
|
||||
assert (
|
||||
session._training_config.seed == 67890
|
||||
), "Expected output is based upon a agent that was trained with seed 67890"
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
print(actual_mean_reward_per_episode, "THISt")
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
|
||||
@@ -6,6 +6,8 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
@@ -41,6 +43,9 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
return copy_path
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
|
||||
)
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
@@ -97,6 +102,7 @@ def test_load_sb3_session():
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_load_primaite_session():
|
||||
"""Test that loading a Primaite session works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
@@ -157,6 +163,7 @@ def test_load_primaite_session():
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_run_loading():
|
||||
"""Test loading session via main.run."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import HardwareState
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
@@ -19,16 +20,17 @@ def run_generic_set_actions(env: Primaite):
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
# print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1]
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from server_1 to node_1 on port FTP
|
||||
action = 7
|
||||
action = 56
|
||||
elif step == 7:
|
||||
# [1, 1, 2, 0] Node Action
|
||||
# Sets Node 1 Hardware State to OFF
|
||||
# Does not resolve any service
|
||||
action = 16
|
||||
action = 128
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
@@ -57,6 +59,10 @@ def run_generic_set_actions(env: Primaite):
|
||||
)
|
||||
def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"""Test single action space is valid."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
|
||||
@@ -73,7 +79,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
if len(dict_item) == 4:
|
||||
contains_node_actions = True
|
||||
# Link action detected
|
||||
elif len(dict_item) == 6:
|
||||
elif len(dict_item) == 7:
|
||||
contains_acl_actions = True
|
||||
# If both are there then the ANY action type is working
|
||||
if contains_node_actions and contains_acl_actions:
|
||||
@@ -94,6 +100,10 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
)
|
||||
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
@@ -105,11 +115,15 @@ def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
access_control_list = env.acl
|
||||
# Use the Access Control List object acl object attribute to get dictionary
|
||||
# Use dictionary.values() to get total list of all items in the dictionary
|
||||
acl_rules_list = access_control_list.acl.values()
|
||||
acl_rules_list = access_control_list.acl
|
||||
# Length of this list tells you how many items are in the dictionary
|
||||
# This number is the frequency of Access Control Rules in the environment
|
||||
# In the scenario, we specified that the agent should create only 1 acl rule
|
||||
num_of_rules = len(acl_rules_list)
|
||||
# This 1 rule added to the implicit deny means there should be 2 rules in total.
|
||||
rules_count = 0
|
||||
for rule in acl_rules_list:
|
||||
if isinstance(rule, ACLRule):
|
||||
rules_count += 1
|
||||
# Therefore these statements below MUST be true
|
||||
assert computer_node_hardware_state == HardwareState.OFF
|
||||
assert num_of_rules == 1
|
||||
assert rules_count == 2
|
||||
|
||||
Reference in New Issue
Block a user