Merge branch 'dev' into feature/1635_Update_Primaite_Session_page_in_Docs
# Conflicts: # src/primaite/setup/old_installation_clean_up.py # src/primaite/setup/reset_example_configs.py
This commit is contained in:
@@ -66,11 +66,11 @@ The environment config file consists of the following attributes:
|
||||
.. code-block:: yaml
|
||||
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
combine_service_traffic : False
|
||||
quantisation_levels: 99
|
||||
@@ -80,6 +80,7 @@ The environment config file consists of the following attributes:
|
||||
|
||||
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
|
||||
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
|
||||
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
|
||||
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
|
||||
|
||||
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
|
||||
@@ -128,6 +129,14 @@ The environment config file consists of the following attributes:
|
||||
|
||||
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
|
||||
|
||||
* **implicit_acl_rule** [str]
|
||||
|
||||
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
|
||||
|
||||
* **max_number_acl_rules** [int]
|
||||
|
||||
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
|
||||
|
||||
**Reward-Based Config Values**
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
@@ -477,3 +486,4 @@ The lay down config file consists of the following attributes:
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
|
||||
|
||||
@@ -53,3 +53,5 @@ v1.2 to v2.0 Migration guide
|
||||
* hard coded agent view
|
||||
|
||||
Each of these items have default values which are designed so that PrimAITE has the same behaviour as it did in 1.2.0, so you do not have to specify them.
|
||||
|
||||
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.
|
||||
|
||||
@@ -1,16 +1,38 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
from typing import Dict
|
||||
import logging
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise an empty AccessControlList."""
|
||||
self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
|
||||
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
# Implicit rule in ACL list
|
||||
if self.acl_implicit_permission == RulePermissionType.DENY:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
|
||||
else:
|
||||
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
|
||||
|
||||
# Maximum number of ACL Rules in ACL
|
||||
self.max_acl_rules: int = max_acl_rules
|
||||
# A list of ACL Rules
|
||||
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Union[ACLRule, None]]:
|
||||
"""Public access method for private _acl."""
|
||||
return self._acl + [self.acl_implicit_rule]
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
@@ -47,21 +69,24 @@ class AccessControlList:
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule_value.get_permission() == "DENY":
|
||||
return True
|
||||
elif rule_value.get_permission() == "ALLOW":
|
||||
return False
|
||||
for rule in self.acl:
|
||||
if isinstance(rule, ACLRule):
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == RulePermissionType.DENY:
|
||||
return True
|
||||
elif rule.get_permission() == RulePermissionType.ALLOW:
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
def add_rule(
|
||||
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: int
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
@@ -71,12 +96,36 @@ class AccessControlList:
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
|
||||
"""
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
try:
|
||||
position_index = int(_position)
|
||||
except TypeError:
|
||||
_LOGGER.info(f"Position {_position} could not be converted to integer.")
|
||||
return
|
||||
|
||||
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# Checks position is in correct range
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
try:
|
||||
_LOGGER.info(f"Position {position_index} is valid.")
|
||||
# Check to see Agent will not overwrite current ACL in ACL list
|
||||
if self._acl[position_index] is None:
|
||||
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
|
||||
# Adds rule
|
||||
self._acl[position_index] = new_rule
|
||||
else:
|
||||
# Cannot overwrite it
|
||||
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
|
||||
return
|
||||
except Exception:
|
||||
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
|
||||
else:
|
||||
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
|
||||
|
||||
def remove_rule(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
@@ -87,17 +136,17 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
try:
|
||||
self.acl.pop(hash_value)
|
||||
except Exception:
|
||||
return
|
||||
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
delete_rule_hash = hash(rule_to_delete)
|
||||
|
||||
for index in range(0, len(self._acl)):
|
||||
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
|
||||
self._acl[index] = None
|
||||
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
for i in range(len(self._acl)):
|
||||
self._acl[i] = None
|
||||
|
||||
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
|
||||
"""
|
||||
@@ -129,16 +178,13 @@ class AccessControlList:
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules: Dict[int, ACLRule] = {}
|
||||
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (
|
||||
rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY"
|
||||
) and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY"
|
||||
relevant_rules = {}
|
||||
for rule in self.acl:
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
|
||||
):
|
||||
# There's a matching rule.
|
||||
relevant_rules[rule_key] = rule_value
|
||||
relevant_rules[self._acl.index(rule)] = rule
|
||||
|
||||
return relevant_rules
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
def __init__(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
@@ -15,7 +18,7 @@ class ACLRule:
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission: str = _permission
|
||||
self.permission: RulePermissionType = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
|
||||
@@ -198,3 +198,11 @@ class SB3OutputVerboseLevel(IntEnum):
|
||||
NONE = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Any firewall rule type."""
|
||||
|
||||
NONE = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
|
||||
@@ -163,3 +163,4 @@
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
|
||||
@@ -243,6 +243,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '26'
|
||||
permission: ALLOW
|
||||
@@ -250,6 +251,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '27'
|
||||
permission: ALLOW
|
||||
@@ -257,6 +259,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '28'
|
||||
permission: ALLOW
|
||||
@@ -264,6 +267,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '29'
|
||||
permission: ALLOW
|
||||
@@ -271,6 +275,7 @@
|
||||
destination: 192.168.10.13
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '30'
|
||||
permission: DENY
|
||||
@@ -278,6 +283,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '31'
|
||||
permission: DENY
|
||||
@@ -285,6 +291,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '32'
|
||||
permission: DENY
|
||||
@@ -292,6 +299,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: DENY
|
||||
@@ -299,6 +307,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 8
|
||||
- item_type: RED_POL
|
||||
id: '34'
|
||||
start_step: 20
|
||||
|
||||
@@ -111,6 +111,7 @@
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '12'
|
||||
permission: ALLOW
|
||||
@@ -118,6 +119,7 @@
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '13'
|
||||
permission: ALLOW
|
||||
@@ -125,6 +127,7 @@
|
||||
destination: 192.168.1.3
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 20
|
||||
|
||||
@@ -345,6 +345,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
@@ -352,6 +353,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
@@ -359,6 +361,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
@@ -366,6 +369,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
@@ -373,6 +377,7 @@
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
@@ -380,6 +385,7 @@
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
@@ -387,6 +393,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
@@ -394,6 +401,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
@@ -401,6 +409,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 8
|
||||
- item_type: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
@@ -408,6 +417,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 9
|
||||
- item_type: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
@@ -415,6 +425,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 10
|
||||
- item_type: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
@@ -422,6 +433,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 11
|
||||
- item_type: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
@@ -429,6 +441,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 12
|
||||
- item_type: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
@@ -436,6 +449,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 13
|
||||
- item_type: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
@@ -443,6 +457,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 14
|
||||
- item_type: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
@@ -450,6 +465,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 15
|
||||
- item_type: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
@@ -457,6 +473,7 @@
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 16
|
||||
- item_type: RED_POL
|
||||
id: '50'
|
||||
start_step: 50
|
||||
|
||||
@@ -35,7 +35,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -51,15 +51,15 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
@@ -90,6 +90,11 @@ session_type: TRAIN_EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -15,6 +15,7 @@ from primaite.common.enums import (
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
RulePermissionType,
|
||||
SB3OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
@@ -99,6 +100,12 @@ class TrainingConfig:
|
||||
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
|
||||
"Stable Baselines3 learn/eval output verbosity level"
|
||||
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: float = 0
|
||||
@@ -207,6 +214,7 @@ class TrainingConfig:
|
||||
"session_type": SessionType,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
@@ -233,6 +241,7 @@ class TrainingConfig:
|
||||
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
data["implicit_acl_rule"] = self.implicit_acl_rule.name
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
@@ -275,6 +276,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
services = self.env.services_list
|
||||
|
||||
structure = []
|
||||
|
||||
for _, node in self.env.nodes.items():
|
||||
node_id = node.node_id
|
||||
structure.append(f"node_{node_id}_hardware_state_NONE")
|
||||
@@ -403,6 +405,182 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
return structure
|
||||
|
||||
|
||||
class AccessControlList(AbstractObservationComponent):
|
||||
"""Flat list of all the Access Control Rules in the Access Control List.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
acl_rule1 permission,
|
||||
acl_rule1 source_ip,
|
||||
acl_rule1 dest_ip,
|
||||
acl_rule1 protocol,
|
||||
acl_rule1 port,
|
||||
acl_rule1 position,
|
||||
acl_rule2 permission,
|
||||
acl_rule2 source_ip,
|
||||
acl_rule2 dest_ip,
|
||||
acl_rule2 protocol,
|
||||
acl_rule2 port,
|
||||
acl_rule2 position,
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
Terms (for ACL Observation Space):
|
||||
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
|
||||
# Number of ACL rules incremented by 1 for positions starting at index 0.
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
permission = acl_rule.permission
|
||||
source_ip = acl_rule.source_ip
|
||||
dest_ip = acl_rule.dest_ip
|
||||
protocol = acl_rule.protocol
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == RulePermissionType.DENY:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to source IP address
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == source_ip:
|
||||
source_ip_int = int(node.node_id) + 1
|
||||
break
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to dest IP address
|
||||
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == dest_ip:
|
||||
dest_ip_int = int(node.node_id) + 1
|
||||
if protocol == "ANY":
|
||||
protocol_int = 1
|
||||
else:
|
||||
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
|
||||
try:
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = None
|
||||
if port == "ANY":
|
||||
port_int = 1
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port) + 2
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
port_int = None
|
||||
# Add to current obs
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
# The Nothing or NA representation of 'NONE' ACL rules
|
||||
obs.extend([0, 0, 0, 0, 0, 0])
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for acl_rule in self.env.acl.acl:
|
||||
acl_rule_id = self.env.acl.acl.index(acl_rule)
|
||||
|
||||
for permission in RulePermissionType:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
|
||||
for service in self.env.services_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
|
||||
for port in self.env.ports_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
|
||||
|
||||
return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""
|
||||
Component-based observation space handler.
|
||||
@@ -415,6 +593,7 @@ class ObservationsHandler:
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
"ACCESS_CONTROL_LIST": AccessControlList,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -430,8 +609,6 @@ class ObservationsHandler:
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
@@ -485,7 +662,7 @@ class ObservationsHandler:
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_space
|
||||
else:
|
||||
return self._space
|
||||
@@ -493,7 +670,7 @@ class ObservationsHandler:
|
||||
@property
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_observation
|
||||
else:
|
||||
return self._observation
|
||||
@@ -528,9 +705,6 @@ class ObservationsHandler:
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
if obs_space_config.get("flatten"):
|
||||
handler.flatten = True
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
|
||||
@@ -125,7 +125,12 @@ class Primaite(Env):
|
||||
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
|
||||
|
||||
# Create the Access Control List
|
||||
self.acl: AccessControlList = AccessControlList()
|
||||
self.acl: AccessControlList = AccessControlList(
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
# Sets limit for number of ACL rules in environment
|
||||
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list: List[str] = []
|
||||
@@ -458,12 +463,11 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
|
||||
elif len(self.action_dict[_action]) == 7: # ACL actions in multidiscrete form have len 7
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
@@ -574,6 +578,7 @@ class Primaite(Env):
|
||||
action_destination_ip = readable_action[3]
|
||||
action_protocol = readable_action[4]
|
||||
action_port = readable_action[5]
|
||||
acl_rule_position = readable_action[6]
|
||||
|
||||
if action_decision == 0:
|
||||
# It's decided to do nothing
|
||||
@@ -623,6 +628,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_rule_position,
|
||||
)
|
||||
elif action_decision == 2:
|
||||
# Remove the rule
|
||||
@@ -1018,6 +1024,7 @@ class Primaite(Env):
|
||||
acl_rule_destination = item["destination"]
|
||||
acl_rule_protocol = item["protocol"]
|
||||
acl_rule_port = item["port"]
|
||||
acl_rule_position = item["position"]
|
||||
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -1025,6 +1032,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_rule_position,
|
||||
)
|
||||
|
||||
# TODO: confirm typehint using runtime
|
||||
@@ -1182,6 +1190,11 @@ class Primaite(Env):
|
||||
...
|
||||
}
|
||||
"""
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [1, 0, 0, 0]}
|
||||
action_key = 1
|
||||
@@ -1203,8 +1216,16 @@ class Primaite(Env):
|
||||
|
||||
def create_acl_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
@@ -1216,18 +1237,21 @@ class Primaite(Env):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
for position in range(self.max_number_acl_rules - 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
position,
|
||||
]
|
||||
# Check to see if it is an action we want to include as possible
|
||||
# i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -92,6 +92,7 @@
|
||||
destination: 192.168.1.2
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '7'
|
||||
permission: ALLOW
|
||||
@@ -99,3 +100,4 @@
|
||||
destination: 192.168.1.1
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
|
||||
86
tests/config/obs_tests/laydown_ACL.yaml
Normal file
86
tests/config/obs_tests/laydown_ACL.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- port: '21'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- name: FTP
|
||||
|
||||
########################################
|
||||
# Nodes
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.1
|
||||
software_state: COMPROMISED
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: FTP
|
||||
port: '21'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: FTP
|
||||
port: '21'
|
||||
state: OVERWHELMED
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
|
||||
########################################
|
||||
# Links
|
||||
- item_type: LINK
|
||||
id: '4'
|
||||
name: link1
|
||||
bandwidth: 1000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '5'
|
||||
name: link2
|
||||
bandwidth: 1000
|
||||
source: '3'
|
||||
destination: '2'
|
||||
|
||||
#########################################
|
||||
# IERS
|
||||
- item_type: GREEN_IER
|
||||
id: '5'
|
||||
start_step: 0
|
||||
end_step: 5
|
||||
load: 999
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
|
||||
#########################################
|
||||
# ACL Rules
|
||||
106
tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml
Normal file
106
tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml
Normal file
@@ -0,0 +1,106 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_framework: SB3
|
||||
agent_identifier: PPO
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 1
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 5
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 3
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -39,6 +39,11 @@ observation_space:
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
|
||||
@@ -37,6 +37,11 @@ observation_space:
|
||||
time_delay: 1
|
||||
# Filename of the scenario / laydown
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
|
||||
@@ -40,7 +40,8 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -91,6 +91,8 @@ session_type: EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
implicit_acl_rule: DENY
|
||||
max_number_acl_rules: 10
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -36,7 +36,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: None
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -55,11 +55,11 @@ hard_coded_agent_view: FULL
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# - name: ACCESS_CONTROL_LIST
|
||||
# Number of episodes to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: 67890
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -55,7 +55,6 @@ hard_coded_agent_view: FULL
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
|
||||
@@ -38,6 +38,15 @@ load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 10
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '21'
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: ftp
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: node
|
||||
@@ -16,8 +16,8 @@
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: COMPROMISED
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
@@ -30,15 +30,15 @@
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: COMPROMISED
|
||||
- item_type: RED_IER
|
||||
id: '3'
|
||||
start_step: 2
|
||||
end_step: 15
|
||||
load: 1000
|
||||
protocol: ftp
|
||||
protocol: TCP
|
||||
port: CORRUPT
|
||||
source: '1'
|
||||
destination: '2'
|
||||
|
||||
@@ -47,6 +47,9 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
||||
implicit_acl_rule: DENY
|
||||
max_number_acl_rules: 10
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -3,40 +3,41 @@
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
def test_acl_address_match_1():
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
@@ -44,14 +45,15 @@ def test_acl_address_match_4():
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "ALLOW"
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -59,22 +61,23 @@ def test_check_acl_block_affirmative():
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
|
||||
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "DENY"
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -82,6 +85,7 @@ def test_check_acl_block_negative():
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
|
||||
@@ -90,11 +94,73 @@ def test_check_acl_block_negative():
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
|
||||
def test_delete_rule():
|
||||
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.ALLOW, 10)
|
||||
|
||||
# Create a first rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a second rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "20"
|
||||
acl_rule_destination = "30"
|
||||
acl_rule_protocol = "FTP"
|
||||
acl_rule_port = "21"
|
||||
acl_position_in_list = "2"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a third rule
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.3"
|
||||
acl_rule_destination = "192.168.1.1"
|
||||
acl_rule_protocol = "UDP"
|
||||
acl_rule_port = "60"
|
||||
acl_position_in_list = "4"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
# Remove the second ACL rule added from the list
|
||||
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
|
||||
|
||||
assert len(acl.acl) == 10
|
||||
assert acl.acl[2] is None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Test env creation and behaviour with different observation spaces."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
@@ -237,3 +238,140 @@ class TestLinkTrafficLevels:
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestAccessControlList:
|
||||
"""Test the AccessControlList observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with MultiDiscrete observation space.
|
||||
|
||||
The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have.
|
||||
Each ACL Rule in the observation space has 6 different elements:
|
||||
|
||||
6 * 3 = 18
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
assert env.env_obs.shape == (18,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* one ACL IMPLICIT DENY rule
|
||||
|
||||
Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at
|
||||
Position 2.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_implicit_rule(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0
|
||||
|
||||
On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1
|
||||
|
||||
The final observation space should be this:
|
||||
[2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 43
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 96
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_different_positions(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1
|
||||
|
||||
On Step 4 there is a second ACL rule added at Position 0: 2,4,2,3,3,0
|
||||
|
||||
The final observation space should be this:
|
||||
[2 , 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added before and has a LOWER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 44
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 95
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
@@ -11,30 +11,46 @@ from tests import TEST_CONFIG_ROOT
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session):
|
||||
"""Test running seeded learning produces the same output when ran twice."""
|
||||
"""
|
||||
Test running seeded learning produces the same output when ran twice.
|
||||
|
||||
.. note::
|
||||
|
||||
If this is failing, the hard-coded expected_mean_reward_per_episode
|
||||
from a pre-trained agent will probably need to be updated. If the
|
||||
env changes and those changed how this agent is trained, chances are
|
||||
the mean rewards are going to be different.
|
||||
|
||||
Run the test, but print out the session.learn_av_reward_per_episode()
|
||||
before comparing it. Then copy the printed dict and replace the
|
||||
expected_mean_reward_per_episode with those values. The test should
|
||||
now work. If not, then you've got a bug :).
|
||||
"""
|
||||
expected_mean_reward_per_episode = {
|
||||
1: -90.703125,
|
||||
2: -91.15234375,
|
||||
3: -87.5,
|
||||
4: -92.2265625,
|
||||
5: -94.6875,
|
||||
6: -91.19140625,
|
||||
7: -88.984375,
|
||||
8: -88.3203125,
|
||||
9: -112.79296875,
|
||||
10: -100.01953125,
|
||||
1: -20.7421875,
|
||||
2: -19.82421875,
|
||||
3: -17.01171875,
|
||||
4: -19.08203125,
|
||||
5: -21.93359375,
|
||||
6: -20.21484375,
|
||||
7: -15.546875,
|
||||
8: -12.08984375,
|
||||
9: -17.59765625,
|
||||
10: -14.6875,
|
||||
}
|
||||
|
||||
with temp_primaite_session as session:
|
||||
assert session._training_config.seed == 67890, (
|
||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||
)
|
||||
assert (
|
||||
session._training_config.seed == 67890
|
||||
), "Expected output is based upon a agent that was trained with seed 67890"
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
print(actual_mean_reward_per_episode, "THISt")
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
|
||||
@@ -43,6 +43,9 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
return copy_path
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
|
||||
)
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import HardwareState
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
@@ -19,16 +20,17 @@ def run_generic_set_actions(env: Primaite):
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
# print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1]
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from server_1 to node_1 on port FTP
|
||||
action = 7
|
||||
action = 56
|
||||
elif step == 7:
|
||||
# [1, 1, 2, 0] Node Action
|
||||
# Sets Node 1 Hardware State to OFF
|
||||
# Does not resolve any service
|
||||
action = 16
|
||||
action = 128
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
@@ -57,6 +59,10 @@ def run_generic_set_actions(env: Primaite):
|
||||
)
|
||||
def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"""Test single action space is valid."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
|
||||
@@ -73,7 +79,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
if len(dict_item) == 4:
|
||||
contains_node_actions = True
|
||||
# Link action detected
|
||||
elif len(dict_item) == 6:
|
||||
elif len(dict_item) == 7:
|
||||
contains_acl_actions = True
|
||||
# If both are there then the ANY action type is working
|
||||
if contains_node_actions and contains_acl_actions:
|
||||
@@ -94,6 +100,10 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
)
|
||||
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
@@ -105,11 +115,15 @@ def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
access_control_list = env.acl
|
||||
# Use the Access Control List object acl object attribute to get dictionary
|
||||
# Use dictionary.values() to get total list of all items in the dictionary
|
||||
acl_rules_list = access_control_list.acl.values()
|
||||
acl_rules_list = access_control_list.acl
|
||||
# Length of this list tells you how many items are in the dictionary
|
||||
# This number is the frequency of Access Control Rules in the environment
|
||||
# In the scenario, we specified that the agent should create only 1 acl rule
|
||||
num_of_rules = len(acl_rules_list)
|
||||
# This 1 rule added to the implicit deny means there should be 2 rules in total.
|
||||
rules_count = 0
|
||||
for rule in acl_rules_list:
|
||||
if isinstance(rule, ACLRule):
|
||||
rules_count += 1
|
||||
# Therefore these statements below MUST be true
|
||||
assert computer_node_hardware_state == HardwareState.OFF
|
||||
assert num_of_rules == 1
|
||||
assert rules_count == 2
|
||||
|
||||
Reference in New Issue
Block a user