Merged PR 120: Change Functionality of ACL Rules
## Summary
### ACL List
First change was I changed `access_control_list.py` from a `dict` to a `list` so it is now an ordered structure. This was done so I could implement the positions inside the `ACL` and `ANY` action spaces.
From this, some functions have changed such as `add_rule` and `remove_rule`, `is_blocked` and `get_relevant_rules`.
The ACL list is now a fixed size and on initialisation it is filled with `None` types. When a function calls `self.acl` the `implicit rule` (if there is one) is added after the last `ACLRule` object in the list. The remainder of the list (if there is left over space) is padded out with `None`.
As the agent adds rules, the `None` are replaced by `ACLRule` objects and the agent cannot overwrite an existing `ACLRule` with another, it can only write over `None` types.
### ACL Training Config Changes
Changes have been made to the `training_config_main.yaml`. There are 2 new items:
`implicit_acl_rule:` - Implicit ACL firewall rule at end of list to be default action (ALLOW or DENY)
`max_number_acl_rules:` - Total number of ACL rules allowed in the environment
In the `OBSERVATION_SPACE` area of the config, `ACCESS_CONTROL_LIST` can be selected
They have default values if none are specified so for the older configs - these values are in the `TrainingConfig` dataclass.
### ACL and ANY Action Spaces
I changed the ACL space from length of 6 to 7. I have included the `position` of where the agent wants to position the ACL Rule.
`position` = index in `self.acl` with bounds [0 to ...]
As a result, total possible actions have gone up.
### ACL Observation Space
In the observations.py I have made a new observation component: Access Control List.
It has the following mappings/meanings:
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
I created a new 0 meaning, which means NA and represents the None objects in the ACLList.
Also, there is no 'flatten' in the observation space components and this has been done in the observations.py now if there are multiple components.
## Test process
I have written tests in a new `TestAccessControlList` object in `test_observations.py`.
I ran a single test which was 1000 episodes, SB3/PPO, Config 5 and ACL Observation Space. I seemed to get some interesting results which may need investigating on Monday.

## Checklist
- ...
This commit is contained in:
@@ -66,11 +66,11 @@ The environment config file consists of the following attributes:
|
||||
.. code-block:: yaml
|
||||
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
combine_service_traffic : False
|
||||
quantisation_levels: 99
|
||||
@@ -80,6 +80,7 @@ The environment config file consists of the following attributes:
|
||||
|
||||
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
|
||||
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
|
||||
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
|
||||
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
|
||||
|
||||
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
|
||||
@@ -128,6 +129,14 @@ The environment config file consists of the following attributes:
|
||||
|
||||
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
|
||||
|
||||
* **implicit_acl_rule** [str]
|
||||
|
||||
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
|
||||
|
||||
* **max_number_acl_rules** [int]
|
||||
|
||||
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
|
||||
|
||||
**Reward-Based Config Values**
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
@@ -477,3 +486,4 @@ The lay down config file consists of the following attributes:
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
|
||||
|
||||
@@ -53,3 +53,5 @@ v1.2 to v2.0 Migration guide
|
||||
* hard coded agent view
|
||||
|
||||
Each of these items have default values which are designed so that PrimAITE has the same behaviour as it did in 1.2.0, so you do not have to specify them.
|
||||
|
||||
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.
|
||||
|
||||
@@ -1,16 +1,38 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
from typing import Dict
|
||||
import logging
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise an empty AccessControlList."""
|
||||
self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
|
||||
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
# Implicit rule in ACL list
|
||||
if self.acl_implicit_permission == RulePermissionType.DENY:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
|
||||
else:
|
||||
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
|
||||
|
||||
# Maximum number of ACL Rules in ACL
|
||||
self.max_acl_rules: int = max_acl_rules
|
||||
# A list of ACL Rules
|
||||
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Union[ACLRule, None]]:
|
||||
"""Public access method for private _acl."""
|
||||
return self._acl + [self.acl_implicit_rule]
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
@@ -47,21 +69,24 @@ class AccessControlList:
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule_value.get_permission() == "DENY":
|
||||
return True
|
||||
elif rule_value.get_permission() == "ALLOW":
|
||||
return False
|
||||
for rule in self.acl:
|
||||
if isinstance(rule, ACLRule):
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == RulePermissionType.DENY:
|
||||
return True
|
||||
elif rule.get_permission() == RulePermissionType.ALLOW:
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
def add_rule(
|
||||
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: int
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
@@ -71,12 +96,36 @@ class AccessControlList:
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
|
||||
"""
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
try:
|
||||
position_index = int(_position)
|
||||
except TypeError:
|
||||
_LOGGER.info(f"Position {_position} could not be converted to integer.")
|
||||
return
|
||||
|
||||
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# Checks position is in correct range
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
try:
|
||||
_LOGGER.info(f"Position {position_index} is valid.")
|
||||
# Check to see Agent will not overwrite current ACL in ACL list
|
||||
if self._acl[position_index] is None:
|
||||
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
|
||||
# Adds rule
|
||||
self._acl[position_index] = new_rule
|
||||
else:
|
||||
# Cannot overwrite it
|
||||
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
|
||||
return
|
||||
except Exception:
|
||||
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
|
||||
else:
|
||||
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
|
||||
|
||||
def remove_rule(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
@@ -87,17 +136,17 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
try:
|
||||
self.acl.pop(hash_value)
|
||||
except Exception:
|
||||
return
|
||||
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
delete_rule_hash = hash(rule_to_delete)
|
||||
|
||||
for index in range(0, len(self._acl)):
|
||||
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
|
||||
self._acl[index] = None
|
||||
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
for i in range(len(self._acl)):
|
||||
self._acl[i] = None
|
||||
|
||||
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
|
||||
"""
|
||||
@@ -129,16 +178,13 @@ class AccessControlList:
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules: Dict[int, ACLRule] = {}
|
||||
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (
|
||||
rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY"
|
||||
) and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY"
|
||||
relevant_rules = {}
|
||||
for rule in self.acl:
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
|
||||
):
|
||||
# There's a matching rule.
|
||||
relevant_rules[rule_key] = rule_value
|
||||
relevant_rules[self._acl.index(rule)] = rule
|
||||
|
||||
return relevant_rules
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
def __init__(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
@@ -15,7 +18,7 @@ class ACLRule:
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission: str = _permission
|
||||
self.permission: RulePermissionType = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
|
||||
@@ -198,3 +198,11 @@ class SB3OutputVerboseLevel(IntEnum):
|
||||
NONE = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Any firewall rule type."""
|
||||
|
||||
NONE = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
|
||||
@@ -163,3 +163,4 @@
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
|
||||
@@ -243,6 +243,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '26'
|
||||
permission: ALLOW
|
||||
@@ -250,6 +251,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '27'
|
||||
permission: ALLOW
|
||||
@@ -257,6 +259,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '28'
|
||||
permission: ALLOW
|
||||
@@ -264,6 +267,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '29'
|
||||
permission: ALLOW
|
||||
@@ -271,6 +275,7 @@
|
||||
destination: 192.168.10.13
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '30'
|
||||
permission: DENY
|
||||
@@ -278,6 +283,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '31'
|
||||
permission: DENY
|
||||
@@ -285,6 +291,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '32'
|
||||
permission: DENY
|
||||
@@ -292,6 +299,7 @@
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: DENY
|
||||
@@ -299,6 +307,7 @@
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 8
|
||||
- item_type: RED_POL
|
||||
id: '34'
|
||||
start_step: 20
|
||||
|
||||
@@ -111,6 +111,7 @@
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '12'
|
||||
permission: ALLOW
|
||||
@@ -118,6 +119,7 @@
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '13'
|
||||
permission: ALLOW
|
||||
@@ -125,6 +127,7 @@
|
||||
destination: 192.168.1.3
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 20
|
||||
|
||||
@@ -345,6 +345,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
@@ -352,6 +353,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
@@ -359,6 +361,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
@@ -366,6 +369,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
@@ -373,6 +377,7 @@
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
@@ -380,6 +385,7 @@
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
@@ -387,6 +393,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
@@ -394,6 +401,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
@@ -401,6 +409,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 8
|
||||
- item_type: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
@@ -408,6 +417,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 9
|
||||
- item_type: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
@@ -415,6 +425,7 @@
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 10
|
||||
- item_type: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
@@ -422,6 +433,7 @@
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 11
|
||||
- item_type: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
@@ -429,6 +441,7 @@
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 12
|
||||
- item_type: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
@@ -436,6 +449,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 13
|
||||
- item_type: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
@@ -443,6 +457,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 14
|
||||
- item_type: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
@@ -450,6 +465,7 @@
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 15
|
||||
- item_type: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
@@ -457,6 +473,7 @@
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 16
|
||||
- item_type: RED_POL
|
||||
id: '50'
|
||||
start_step: 50
|
||||
|
||||
@@ -35,7 +35,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -51,15 +51,15 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
@@ -90,6 +90,11 @@ session_type: TRAIN_EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -15,6 +15,7 @@ from primaite.common.enums import (
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
RulePermissionType,
|
||||
SB3OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
@@ -99,6 +100,12 @@ class TrainingConfig:
|
||||
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
|
||||
"Stable Baselines3 learn/eval output verbosity level"
|
||||
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: float = 0
|
||||
@@ -207,6 +214,7 @@ class TrainingConfig:
|
||||
"session_type": SessionType,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
@@ -233,6 +241,7 @@ class TrainingConfig:
|
||||
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
data["implicit_acl_rule"] = self.implicit_acl_rule.name
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
@@ -275,6 +276,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
services = self.env.services_list
|
||||
|
||||
structure = []
|
||||
|
||||
for _, node in self.env.nodes.items():
|
||||
node_id = node.node_id
|
||||
structure.append(f"node_{node_id}_hardware_state_NONE")
|
||||
@@ -403,6 +405,182 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
return structure
|
||||
|
||||
|
||||
class AccessControlList(AbstractObservationComponent):
|
||||
"""Flat list of all the Access Control Rules in the Access Control List.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
acl_rule1 permission,
|
||||
acl_rule1 source_ip,
|
||||
acl_rule1 dest_ip,
|
||||
acl_rule1 protocol,
|
||||
acl_rule1 port,
|
||||
acl_rule1 position,
|
||||
acl_rule2 permission,
|
||||
acl_rule2 source_ip,
|
||||
acl_rule2 dest_ip,
|
||||
acl_rule2 protocol,
|
||||
acl_rule2 port,
|
||||
acl_rule2 position,
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
Terms (for ACL Observation Space):
|
||||
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
|
||||
# Number of ACL rules incremented by 1 for positions starting at index 0.
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
permission = acl_rule.permission
|
||||
source_ip = acl_rule.source_ip
|
||||
dest_ip = acl_rule.dest_ip
|
||||
protocol = acl_rule.protocol
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == RulePermissionType.DENY:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to source IP address
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == source_ip:
|
||||
source_ip_int = int(node.node_id) + 1
|
||||
break
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to dest IP address
|
||||
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == dest_ip:
|
||||
dest_ip_int = int(node.node_id) + 1
|
||||
if protocol == "ANY":
|
||||
protocol_int = 1
|
||||
else:
|
||||
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
|
||||
try:
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = None
|
||||
if port == "ANY":
|
||||
port_int = 1
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port) + 2
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
port_int = None
|
||||
# Add to current obs
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
# The Nothing or NA representation of 'NONE' ACL rules
|
||||
obs.extend([0, 0, 0, 0, 0, 0])
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for acl_rule in self.env.acl.acl:
|
||||
acl_rule_id = self.env.acl.acl.index(acl_rule)
|
||||
|
||||
for permission in RulePermissionType:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
|
||||
for service in self.env.services_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
|
||||
for port in self.env.ports_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
|
||||
|
||||
return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""
|
||||
Component-based observation space handler.
|
||||
@@ -415,6 +593,7 @@ class ObservationsHandler:
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
"ACCESS_CONTROL_LIST": AccessControlList,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -430,8 +609,6 @@ class ObservationsHandler:
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
@@ -485,7 +662,7 @@ class ObservationsHandler:
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_space
|
||||
else:
|
||||
return self._space
|
||||
@@ -493,7 +670,7 @@ class ObservationsHandler:
|
||||
@property
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_observation
|
||||
else:
|
||||
return self._observation
|
||||
@@ -528,9 +705,6 @@ class ObservationsHandler:
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
if obs_space_config.get("flatten"):
|
||||
handler.flatten = True
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
|
||||
@@ -125,7 +125,12 @@ class Primaite(Env):
|
||||
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
|
||||
|
||||
# Create the Access Control List
|
||||
self.acl: AccessControlList = AccessControlList()
|
||||
self.acl: AccessControlList = AccessControlList(
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
# Sets limit for number of ACL rules in environment
|
||||
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list: List[str] = []
|
||||
@@ -458,12 +463,11 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
|
||||
elif len(self.action_dict[_action]) == 7: # ACL actions in multidiscrete form have len 7
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
@@ -574,6 +578,7 @@ class Primaite(Env):
|
||||
action_destination_ip = readable_action[3]
|
||||
action_protocol = readable_action[4]
|
||||
action_port = readable_action[5]
|
||||
acl_rule_position = readable_action[6]
|
||||
|
||||
if action_decision == 0:
|
||||
# It's decided to do nothing
|
||||
@@ -623,6 +628,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_rule_position,
|
||||
)
|
||||
elif action_decision == 2:
|
||||
# Remove the rule
|
||||
@@ -1018,6 +1024,7 @@ class Primaite(Env):
|
||||
acl_rule_destination = item["destination"]
|
||||
acl_rule_protocol = item["protocol"]
|
||||
acl_rule_port = item["port"]
|
||||
acl_rule_position = item["position"]
|
||||
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -1025,6 +1032,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_rule_position,
|
||||
)
|
||||
|
||||
# TODO: confirm typehint using runtime
|
||||
@@ -1182,6 +1190,11 @@ class Primaite(Env):
|
||||
...
|
||||
}
|
||||
"""
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [1, 0, 0, 0]}
|
||||
action_key = 1
|
||||
@@ -1203,8 +1216,16 @@ class Primaite(Env):
|
||||
|
||||
def create_acl_action_dict(self) -> Dict[int, List[int]]:
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
@@ -1216,18 +1237,21 @@ class Primaite(Env):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
for position in range(self.max_number_acl_rules - 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
position,
|
||||
]
|
||||
# Check to see if it is an action we want to include as possible
|
||||
# i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
from logging import Logger
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -2,16 +2,13 @@
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -92,6 +92,7 @@
|
||||
destination: 192.168.1.2
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '7'
|
||||
permission: ALLOW
|
||||
@@ -99,3 +100,4 @@
|
||||
destination: 192.168.1.1
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
|
||||
86
tests/config/obs_tests/laydown_ACL.yaml
Normal file
86
tests/config/obs_tests/laydown_ACL.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- port: '21'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- name: FTP
|
||||
|
||||
########################################
|
||||
# Nodes
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.1
|
||||
software_state: COMPROMISED
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: FTP
|
||||
port: '21'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: FTP
|
||||
port: '21'
|
||||
state: OVERWHELMED
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
|
||||
########################################
|
||||
# Links
|
||||
- item_type: LINK
|
||||
id: '4'
|
||||
name: link1
|
||||
bandwidth: 1000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '5'
|
||||
name: link2
|
||||
bandwidth: 1000
|
||||
source: '3'
|
||||
destination: '2'
|
||||
|
||||
#########################################
|
||||
# IERS
|
||||
- item_type: GREEN_IER
|
||||
id: '5'
|
||||
start_step: 0
|
||||
end_step: 5
|
||||
load: 999
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
|
||||
#########################################
|
||||
# ACL Rules
|
||||
106
tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml
Normal file
106
tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml
Normal file
@@ -0,0 +1,106 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_framework: SB3
|
||||
agent_identifier: PPO
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 1
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 5
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 3
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -39,6 +39,11 @@ observation_space:
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
|
||||
@@ -37,6 +37,11 @@ observation_space:
|
||||
time_delay: 1
|
||||
# Filename of the scenario / laydown
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
|
||||
@@ -40,7 +40,8 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -91,6 +91,8 @@ session_type: EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
implicit_acl_rule: DENY
|
||||
max_number_acl_rules: 10
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -36,7 +36,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: None
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -55,11 +55,11 @@ hard_coded_agent_view: FULL
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# - name: ACCESS_CONTROL_LIST
|
||||
# Number of episodes to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ random_red_agent: False
|
||||
# Default is None (null)
|
||||
seed: 67890
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
@@ -55,7 +55,6 @@ hard_coded_agent_view: FULL
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
|
||||
@@ -38,6 +38,15 @@ load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 10
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '21'
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: ftp
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: node
|
||||
@@ -16,8 +16,8 @@
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: COMPROMISED
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
@@ -30,15 +30,15 @@
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: COMPROMISED
|
||||
- item_type: RED_IER
|
||||
id: '3'
|
||||
start_step: 2
|
||||
end_step: 15
|
||||
load: 1000
|
||||
protocol: ftp
|
||||
protocol: TCP
|
||||
port: CORRUPT
|
||||
source: '1'
|
||||
destination: '2'
|
||||
|
||||
@@ -47,6 +47,9 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
||||
implicit_acl_rule: DENY
|
||||
max_number_acl_rules: 10
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -3,40 +3,41 @@
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
def test_acl_address_match_1():
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
@@ -44,14 +45,15 @@ def test_acl_address_match_4():
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "ALLOW"
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -59,22 +61,23 @@ def test_check_acl_block_affirmative():
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
|
||||
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "DENY"
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
@@ -82,6 +85,7 @@ def test_check_acl_block_negative():
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
|
||||
@@ -90,11 +94,73 @@ def test_check_acl_block_negative():
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
|
||||
def test_delete_rule():
|
||||
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.ALLOW, 10)
|
||||
|
||||
# Create a first rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a second rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "20"
|
||||
acl_rule_destination = "30"
|
||||
acl_rule_protocol = "FTP"
|
||||
acl_rule_port = "21"
|
||||
acl_position_in_list = "2"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a third rule
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.3"
|
||||
acl_rule_destination = "192.168.1.1"
|
||||
acl_rule_protocol = "UDP"
|
||||
acl_rule_port = "60"
|
||||
acl_position_in_list = "4"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
# Remove the second ACL rule added from the list
|
||||
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
|
||||
|
||||
assert len(acl.acl) == 10
|
||||
assert acl.acl[2] is None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Test env creation and behaviour with different observation spaces."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
@@ -237,3 +238,140 @@ class TestLinkTrafficLevels:
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestAccessControlList:
|
||||
"""Test the AccessControlList observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with MultiDiscrete observation space.
|
||||
|
||||
The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have.
|
||||
Each ACL Rule in the observation space has 6 different elements:
|
||||
|
||||
6 * 3 = 18
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
assert env.env_obs.shape == (18,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* one ACL IMPLICIT DENY rule
|
||||
|
||||
Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at
|
||||
Position 2.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_implicit_rule(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0
|
||||
|
||||
On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1
|
||||
|
||||
The final observation space should be this:
|
||||
[2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 43
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 96
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_different_positions(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1
|
||||
|
||||
On Step 4 there is a second ACL rule added at Position 0: 2,4,2,3,3,0
|
||||
|
||||
The final observation space should be this:
|
||||
[2 , 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added before and has a LOWER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 44
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 95
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
@@ -11,30 +11,46 @@ from tests import TEST_CONFIG_ROOT
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session):
|
||||
"""Test running seeded learning produces the same output when ran twice."""
|
||||
"""
|
||||
Test running seeded learning produces the same output when ran twice.
|
||||
|
||||
.. note::
|
||||
|
||||
If this is failing, the hard-coded expected_mean_reward_per_episode
|
||||
from a pre-trained agent will probably need to be updated. If the
|
||||
env changes and those changed how this agent is trained, chances are
|
||||
the mean rewards are going to be different.
|
||||
|
||||
Run the test, but print out the session.learn_av_reward_per_episode()
|
||||
before comparing it. Then copy the printed dict and replace the
|
||||
expected_mean_reward_per_episode with those values. The test should
|
||||
now work. If not, then you've got a bug :).
|
||||
"""
|
||||
expected_mean_reward_per_episode = {
|
||||
1: -90.703125,
|
||||
2: -91.15234375,
|
||||
3: -87.5,
|
||||
4: -92.2265625,
|
||||
5: -94.6875,
|
||||
6: -91.19140625,
|
||||
7: -88.984375,
|
||||
8: -88.3203125,
|
||||
9: -112.79296875,
|
||||
10: -100.01953125,
|
||||
1: -20.7421875,
|
||||
2: -19.82421875,
|
||||
3: -17.01171875,
|
||||
4: -19.08203125,
|
||||
5: -21.93359375,
|
||||
6: -20.21484375,
|
||||
7: -15.546875,
|
||||
8: -12.08984375,
|
||||
9: -17.59765625,
|
||||
10: -14.6875,
|
||||
}
|
||||
|
||||
with temp_primaite_session as session:
|
||||
assert session._training_config.seed == 67890, (
|
||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||
)
|
||||
assert (
|
||||
session._training_config.seed == 67890
|
||||
), "Expected output is based upon a agent that was trained with seed 67890"
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
print(actual_mean_reward_per_episode, "THISt")
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
|
||||
@@ -43,6 +43,9 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
return copy_path
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
|
||||
)
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import HardwareState
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
@@ -19,16 +20,17 @@ def run_generic_set_actions(env: Primaite):
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
# print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1]
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from server_1 to node_1 on port FTP
|
||||
action = 7
|
||||
action = 56
|
||||
elif step == 7:
|
||||
# [1, 1, 2, 0] Node Action
|
||||
# Sets Node 1 Hardware State to OFF
|
||||
# Does not resolve any service
|
||||
action = 16
|
||||
action = 128
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
@@ -57,6 +59,10 @@ def run_generic_set_actions(env: Primaite):
|
||||
)
|
||||
def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"""Test single action space is valid."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
|
||||
@@ -73,7 +79,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
if len(dict_item) == 4:
|
||||
contains_node_actions = True
|
||||
# Link action detected
|
||||
elif len(dict_item) == 6:
|
||||
elif len(dict_item) == 7:
|
||||
contains_acl_actions = True
|
||||
# If both are there then the ANY action type is working
|
||||
if contains_node_actions and contains_acl_actions:
|
||||
@@ -94,6 +100,10 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
)
|
||||
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
@@ -105,11 +115,15 @@ def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
access_control_list = env.acl
|
||||
# Use the Access Control List object acl object attribute to get dictionary
|
||||
# Use dictionary.values() to get total list of all items in the dictionary
|
||||
acl_rules_list = access_control_list.acl.values()
|
||||
acl_rules_list = access_control_list.acl
|
||||
# Length of this list tells you how many items are in the dictionary
|
||||
# This number is the frequency of Access Control Rules in the environment
|
||||
# In the scenario, we specified that the agent should create only 1 acl rule
|
||||
num_of_rules = len(acl_rules_list)
|
||||
# This 1 rule added to the implicit deny means there should be 2 rules in total.
|
||||
rules_count = 0
|
||||
for rule in acl_rules_list:
|
||||
if isinstance(rule, ACLRule):
|
||||
rules_count += 1
|
||||
# Therefore these statements below MUST be true
|
||||
assert computer_node_hardware_state == HardwareState.OFF
|
||||
assert num_of_rules == 1
|
||||
assert rules_count == 2
|
||||
|
||||
Reference in New Issue
Block a user