Merged PR 120: Change Functionality of ACL Rules

## Summary
### ACL List
First change was I changed `access_control_list.py` from a `dict` to a `list` so it is now an ordered structure. This was done so I could implement the positions inside the `ACL` and `ANY` action spaces.

From this, some functions have changed such as `add_rule` and `remove_rule`, `is_blocked` and `get_relevant_rules`.

The ACL list is now a fixed size and on initialisation it is filled with `None` types. When a function calls `self.acl` the `implicit rule` (if there is one) is added after the last `ACLRule` object in the list. The remainder of the list (if there is left over space) is padded out with `None`.

As the agent adds rules, the `None` are replaced by `ACLRule` objects and the agent cannot overwrite an existing `ACLRule` with another, it can only write over `None` types.

### ACL Training Config Changes
Changes have been made to the `training_config_main.yaml`. There are 2 new items:

`implicit_acl_rule:` - Implicit ACL firewall rule at end of list to be default action (ALLOW or DENY)
`max_number_acl_rules:` - Total number of ACL rules allowed in the environment

In the `OBSERVATION_SPACE` area of the config, `ACCESS_CONTROL_LIST` can be selected

They have default values if none are specified so for the older configs - these values are in the `TrainingConfig` dataclass.

### ACL and ANY Action Spaces
I changed the ACL space from length of 6 to 7. I have included the `position` of where the agent wants to position the ACL Rule.

`position` = index in `self.acl` with bounds [0 to ...]

As a result, total possible actions have gone up.

### ACL Observation Space
In the observations.py I have made a new observation component: Access Control List.
It has the following mappings/meanings:

        [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
        [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
        [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
        [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
        [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
        [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)

I created a new 0 meaning, which means NA and represents the None objects in the ACLList.

Also, there is no 'flatten' in the observation space components and this has been done in the observations.py now if there are multiple components.

## Test process
I have written tests in a new `TestAccessControlList` object in `test_observations.py`.

I ran a single test which was 1000 episodes, SB3/PPO, Config 5 and ACL Observation Space. I seemed to get some interesting results which may need investigating on Monday.

![Figure_1.png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/120/attachments/Figure_1.png)

## Checklist
- ...
This commit is contained in:
Sunil Samra
2023-07-18 10:31:15 +00:00
33 changed files with 888 additions and 128 deletions

View File

@@ -66,11 +66,11 @@ The environment config file consists of the following attributes:
.. code-block:: yaml
observation_space:
flatten: true
components:
- name: NODE_LINK_TABLE
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
options:
combine_service_traffic : False
quantisation_levels: 99
@@ -80,6 +80,7 @@ The environment config file consists of the following attributes:
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
@@ -128,6 +129,14 @@ The environment config file consists of the following attributes:
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
* **implicit_acl_rule** [str]
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
* **max_number_acl_rules** [int]
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
**Reward-Based Config Values**
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
@@ -477,3 +486,4 @@ The lay down config file consists of the following attributes:
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).

View File

@@ -53,3 +53,5 @@ v1.2 to v2.0 Migration guide
* hard coded agent view
Each of these items have default values which are designed so that PrimAITE has the same behaviour as it did in 1.2.0, so you do not have to specify them.
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.

View File

@@ -1,16 +1,38 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict
import logging
from typing import Dict, Final, List, Union
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class AccessControlList:
"""Access Control List class."""
def __init__(self) -> None:
"""Initialise an empty AccessControlList."""
self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
"""Init."""
# Implicit ALLOW or DENY firewall spec
self.acl_implicit_permission = implicit_permission
# Implicit rule in ACL list
if self.acl_implicit_permission == RulePermissionType.DENY:
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
else:
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
# Maximum number of ACL Rules in ACL
self.max_acl_rules: int = max_acl_rules
# A list of ACL Rules
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
@property
def acl(self) -> List[Union[ACLRule, None]]:
"""Public access method for private _acl."""
return self._acl + [self.acl_implicit_rule]
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
@@ -47,21 +69,24 @@ class AccessControlList:
Returns:
Indicates block if all conditions are satisfied.
"""
for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule_value.get_permission() == "DENY":
return True
elif rule_value.get_permission() == "ALLOW":
return False
for rule in self.acl:
if isinstance(rule, ACLRule):
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule.get_permission() == RulePermissionType.DENY:
return True
elif rule.get_permission() == RulePermissionType.ALLOW:
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
def add_rule(
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: int
) -> None:
"""
Adds a new rule.
@@ -71,12 +96,36 @@ class AccessControlList:
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
"""
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
try:
position_index = int(_position)
except TypeError:
_LOGGER.info(f"Position {_position} could not be converted to integer.")
return
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
# Checks position is in correct range
if self.max_acl_rules - 1 > position_index > -1:
try:
_LOGGER.info(f"Position {position_index} is valid.")
# Check to see Agent will not overwrite current ACL in ACL list
if self._acl[position_index] is None:
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
# Adds rule
self._acl[position_index] = new_rule
else:
# Cannot overwrite it
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
return
except Exception:
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
else:
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
def remove_rule(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Removes a rule.
@@ -87,17 +136,17 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
# There will not always be something 'popable' since the agent will be trying random things
try:
self.acl.pop(hash_value)
except Exception:
return
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
delete_rule_hash = hash(rule_to_delete)
for index in range(0, len(self._acl)):
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
self._acl[index] = None
def remove_all_rules(self) -> None:
"""Removes all rules."""
self.acl.clear()
for i in range(len(self._acl)):
self._acl[i] = None
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
"""
@@ -129,16 +178,13 @@ class AccessControlList:
:return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[int, ACLRule]
"""
relevant_rules: Dict[int, ACLRule] = {}
for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (
rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY"
) and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY"
relevant_rules = {}
for rule in self.acl:
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
):
# There's a matching rule.
relevant_rules[rule_key] = rule_value
relevant_rules[self._acl.index(rule)] = rule
return relevant_rules

View File

@@ -1,11 +1,14 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType
class ACLRule:
"""Access Control List Rule class."""
def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
def __init__(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Initialise an ACL Rule.
@@ -15,7 +18,7 @@ class ACLRule:
:param _protocol: The rule protocol
:param _port: The rule port
"""
self.permission: str = _permission
self.permission: RulePermissionType = _permission
self.source_ip: str = _source_ip
self.dest_ip: str = _dest_ip
self.protocol: str = _protocol

View File

@@ -198,3 +198,11 @@ class SB3OutputVerboseLevel(IntEnum):
NONE = 0
INFO = 1
DEBUG = 2
class RulePermissionType(Enum):
"""Any firewall rule type."""
NONE = 0
DENY = 1
ALLOW = 2

View File

@@ -163,3 +163,4 @@
destination: ANY
protocol: ANY
port: ANY
position: 0

View File

@@ -243,6 +243,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '26'
permission: ALLOW
@@ -250,6 +251,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '27'
permission: ALLOW
@@ -257,6 +259,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 2
- item_type: ACL_RULE
id: '28'
permission: ALLOW
@@ -264,6 +267,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 3
- item_type: ACL_RULE
id: '29'
permission: ALLOW
@@ -271,6 +275,7 @@
destination: 192.168.10.13
protocol: TCP
port: 80
position: 4
- item_type: ACL_RULE
id: '30'
permission: DENY
@@ -278,6 +283,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 5
- item_type: ACL_RULE
id: '31'
permission: DENY
@@ -285,6 +291,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 6
- item_type: ACL_RULE
id: '32'
permission: DENY
@@ -292,6 +299,7 @@
destination: 192.168.20.15
protocol: TCP
port: 80
position: 7
- item_type: ACL_RULE
id: '33'
permission: DENY
@@ -299,6 +307,7 @@
destination: 192.168.10.14
protocol: TCP
port: 80
position: 8
- item_type: RED_POL
id: '34'
start_step: 20

View File

@@ -111,6 +111,7 @@
destination: 192.168.1.4
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '12'
permission: ALLOW
@@ -118,6 +119,7 @@
destination: 192.168.1.4
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '13'
permission: ALLOW
@@ -125,6 +127,7 @@
destination: 192.168.1.3
protocol: TCP
port: 80
position: 2
- item_type: RED_POL
id: '14'
start_step: 20

View File

@@ -345,6 +345,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 0
- item_type: ACL_RULE
id: '34'
permission: ALLOW
@@ -352,6 +353,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 1
- item_type: ACL_RULE
id: '35'
permission: ALLOW
@@ -359,6 +361,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 2
- item_type: ACL_RULE
id: '36'
permission: ALLOW
@@ -366,6 +369,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 3
- item_type: ACL_RULE
id: '37'
permission: ALLOW
@@ -373,6 +377,7 @@
destination: 192.168.10.11
protocol: ANY
port: ANY
position: 4
- item_type: ACL_RULE
id: '38'
permission: ALLOW
@@ -380,6 +385,7 @@
destination: 192.168.10.12
protocol: ANY
port: ANY
position: 5
- item_type: ACL_RULE
id: '39'
permission: ALLOW
@@ -387,6 +393,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 6
- item_type: ACL_RULE
id: '40'
permission: ALLOW
@@ -394,6 +401,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 7
- item_type: ACL_RULE
id: '41'
permission: ALLOW
@@ -401,6 +409,7 @@
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 8
- item_type: ACL_RULE
id: '42'
permission: ALLOW
@@ -408,6 +417,7 @@
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 9
- item_type: ACL_RULE
id: '43'
permission: ALLOW
@@ -415,6 +425,7 @@
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 10
- item_type: ACL_RULE
id: '44'
permission: ALLOW
@@ -422,6 +433,7 @@
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 11
- item_type: ACL_RULE
id: '45'
permission: ALLOW
@@ -429,6 +441,7 @@
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 12
- item_type: ACL_RULE
id: '46'
permission: ALLOW
@@ -436,6 +449,7 @@
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 13
- item_type: ACL_RULE
id: '47'
permission: ALLOW
@@ -443,6 +457,7 @@
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 14
- item_type: ACL_RULE
id: '48'
permission: ALLOW
@@ -450,6 +465,7 @@
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 15
- item_type: ACL_RULE
id: '49'
permission: DENY
@@ -457,6 +473,7 @@
destination: ANY
protocol: ANY
port: ANY
position: 16
- item_type: RED_POL
id: '50'
start_step: 50

View File

@@ -35,7 +35,7 @@ random_red_agent: False
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
@@ -51,15 +51,15 @@ hard_coded_agent_view: FULL
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
action_type: ANY
# observation space
observation_space:
# flatten: true
flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
# Number of episodes for training to run per session
num_train_episodes: 10
@@ -90,6 +90,11 @@ session_type: TRAIN_EVAL
# The high value for the observation space
observation_space_high_value: 1000000000
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 30
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)

View File

@@ -15,6 +15,7 @@ from primaite.common.enums import (
AgentIdentifier,
DeepLearningFramework,
HardCodedAgentView,
RulePermissionType,
SB3OutputVerboseLevel,
SessionType,
)
@@ -99,6 +100,12 @@ class TrainingConfig:
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
"Stable Baselines3 learn/eval output verbosity level"
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
max_number_acl_rules: int = 30
"Sets a limit for number of acl rules allowed in the list and environment."
# Reward values
# Generic
all_ok: float = 0
@@ -207,6 +214,7 @@ class TrainingConfig:
"session_type": SessionType,
"sb3_output_verbose_level": SB3OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView,
"implicit_acl_rule": RulePermissionType,
}
# convert the string representation of enums into the actual enum values themselves?
@@ -233,6 +241,7 @@ class TrainingConfig:
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
data["session_type"] = self.session_type.name
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
data["implicit_acl_rule"] = self.implicit_acl_rule.name
return data

View File

@@ -8,7 +8,8 @@ from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
from gym import spaces
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
@@ -275,6 +276,7 @@ class NodeStatuses(AbstractObservationComponent):
services = self.env.services_list
structure = []
for _, node in self.env.nodes.items():
node_id = node.node_id
structure.append(f"node_{node_id}_hardware_state_NONE")
@@ -403,6 +405,182 @@ class LinkTrafficLevels(AbstractObservationComponent):
return structure
class AccessControlList(AbstractObservationComponent):
"""Flat list of all the Access Control Rules in the Access Control List.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each ACL Rule has 6 elements. It will have the following structure:
.. code-block::
[
acl_rule1 permission,
acl_rule1 source_ip,
acl_rule1 dest_ip,
acl_rule1 protocol,
acl_rule1 port,
acl_rule1 position,
acl_rule2 permission,
acl_rule2 source_ip,
acl_rule2 dest_ip,
acl_rule2 protocol,
acl_rule2 port,
acl_rule2 position,
...
]
Terms (for ACL Observation Space):
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
"""
Initialise an AccessControlList observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
# Number of ACL rules incremented by 1 for positions starting at index 0.
acl_shape = [
len(RulePermissionType),
len(env.nodes) + 2,
len(env.nodes) + 2,
len(env.services_list) + 2,
len(env.ports_list) + 2,
env.max_number_acl_rules,
]
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission
source_ip = acl_rule.source_ip
dest_ip = acl_rule.dest_ip
protocol = acl_rule.protocol
port = acl_rule.port
position = index
# Map each ACL attribute from what it was to an integer to fit the observation space
source_ip_int = None
dest_ip_int = None
if permission == RulePermissionType.DENY:
permission_int = 1
else:
permission_int = 2
if source_ip == "ANY":
source_ip_int = 1
else:
# Map Node ID (+ 1) to source IP address
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == source_ip:
source_ip_int = int(node.node_id) + 1
break
if dest_ip == "ANY":
dest_ip_int = 1
else:
# Map Node ID (+ 1) to dest IP address
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip:
dest_ip_int = int(node.node_id) + 1
if protocol == "ANY":
protocol_int = 1
else:
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
try:
protocol_int = self.env.services_list.index(protocol) + 2
except AttributeError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = None
if port == "ANY":
port_int = 1
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port) + 2
else:
_LOGGER.info(f"Port {port} could not be found.")
port_int = None
# Add to current obs
obs.extend(
[
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
)
else:
# The Nothing or NA representation of 'NONE' ACL rules
obs.extend([0, 0, 0, 0, 0, 0])
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for acl_rule in self.env.acl.acl:
acl_rule_id = self.env.acl.acl.index(acl_rule)
for permission in RulePermissionType:
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
for service in self.env.services_list:
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
for port in self.env.ports_list:
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
return structure
class ObservationsHandler:
"""
Component-based observation space handler.
@@ -415,6 +593,7 @@ class ObservationsHandler:
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
"ACCESS_CONTROL_LIST": AccessControlList,
}
def __init__(self) -> None:
@@ -430,8 +609,6 @@ class ObservationsHandler:
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
self.flatten: bool = False
def update_obs(self) -> None:
"""Fetch fresh information about the environment."""
current_obs = []
@@ -485,7 +662,7 @@ class ObservationsHandler:
@property
def space(self) -> spaces.Space:
"""Observation space, return the flattened version if flatten is True."""
if self.flatten:
if len(self.registered_obs_components) > 1:
return self._flat_space
else:
return self._space
@@ -493,7 +670,7 @@ class ObservationsHandler:
@property
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""Current observation, return the flattened version if flatten is True."""
if self.flatten:
if len(self.registered_obs_components) > 1:
return self._flat_observation
else:
return self._observation
@@ -528,9 +705,6 @@ class ObservationsHandler:
# Instantiate the handler
handler = cls()
if obs_space_config.get("flatten"):
handler.flatten = True
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]

View File

@@ -125,7 +125,12 @@ class Primaite(Env):
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
# Create the Access Control List
self.acl: AccessControlList = AccessControlList()
self.acl: AccessControlList = AccessControlList(
self.training_config.implicit_acl_rule,
self.training_config.max_number_acl_rules,
)
# Sets limit for number of ACL rules in environment
self.max_number_acl_rules: int = self.training_config.max_number_acl_rules
# Create a list of services (enums)
self.services_list: List[str] = []
@@ -458,12 +463,11 @@ class Primaite(Env):
_action: The action space from the agent
"""
# At the moment, actions are only affecting nodes
if self.training_config.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
elif len(self.action_dict[_action]) == 7: # ACL actions in multidiscrete form have len 7
self.apply_actions_to_acl(_action)
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
@@ -574,6 +578,7 @@ class Primaite(Env):
action_destination_ip = readable_action[3]
action_protocol = readable_action[4]
action_port = readable_action[5]
acl_rule_position = readable_action[6]
if action_decision == 0:
# It's decided to do nothing
@@ -623,6 +628,7 @@ class Primaite(Env):
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_rule_position,
)
elif action_decision == 2:
# Remove the rule
@@ -1018,6 +1024,7 @@ class Primaite(Env):
acl_rule_destination = item["destination"]
acl_rule_protocol = item["protocol"]
acl_rule_port = item["port"]
acl_rule_position = item["position"]
self.acl.add_rule(
acl_rule_permission,
@@ -1025,6 +1032,7 @@ class Primaite(Env):
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_rule_position,
)
# TODO: confirm typehint using runtime
@@ -1182,6 +1190,11 @@ class Primaite(Env):
...
}
"""
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
# reserve 0 action to be a nothing action
actions = {0: [1, 0, 0, 0]}
action_key = 1
@@ -1203,8 +1216,16 @@ class Primaite(Env):
def create_acl_action_dict(self) -> Dict[int, List[int]]:
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
@@ -1216,18 +1237,21 @@ class Primaite(Env):
for dest_ip in range(self.num_nodes + 1):
for protocol in range(self.num_services + 1):
for port in range(self.num_ports + 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
]
# Check to see if its an action we want to include as possible i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
for position in range(self.max_number_acl_rules - 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
position,
]
# Check to see if it is an action we want to include as possible
# i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
return actions

View File

@@ -1,11 +1,8 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from logging import Logger
from primaite import getLogger
if TYPE_CHECKING:
from logging import Logger
_LOGGER: Logger = getLogger(__name__)

View File

@@ -2,16 +2,13 @@
import filecmp
import os
import shutil
from logging import Logger
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger, USERS_CONFIG_DIR
if TYPE_CHECKING:
from logging import Logger
_LOGGER: Logger = getLogger(__name__)

File diff suppressed because one or more lines are too long

View File

@@ -92,6 +92,7 @@
destination: 192.168.1.2
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '7'
permission: ALLOW
@@ -99,3 +100,4 @@
destination: 192.168.1.1
protocol: TCP
port: 80
position: 0

View File

@@ -0,0 +1,86 @@
- item_type: PORTS
ports_list:
- port: '80'
- port: '21'
- item_type: SERVICES
service_list:
- name: TCP
- name: FTP
########################################
# Nodes
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: COMPROMISED
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: FTP
port: '21'
state: GOOD
- item_type: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: FTP
port: '21'
state: OVERWHELMED
- item_type: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- item_type: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- item_type: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'
#########################################
# IERS
- item_type: GREEN_IER
id: '5'
start_step: 0
end_step: 5
load: 999
protocol: TCP
port: '80'
source: '1'
destination: '2'
mission_criticality: 5
#########################################
# ACL Rules

View File

@@ -0,0 +1,106 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_framework: SB3
agent_identifier: PPO
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 5
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 3
observation_space:
components:
- name: ACCESS_CONTROL_LIST
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1_000_000_000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -39,6 +39,11 @@ observation_space:
# Time delay between steps (for generic agents)
time_delay: 1
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: ALLOW
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 4
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAIN
# Determine whether to load an agent from file

View File

@@ -37,6 +37,11 @@ observation_space:
time_delay: 1
# Filename of the scenario / laydown
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: ALLOW
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 4
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False

View File

@@ -40,7 +40,8 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1_000_000_000
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Reward values
# Generic
all_ok: 0

View File

@@ -91,6 +91,8 @@ session_type: EVAL
# The high value for the observation space
observation_space_high_value: 1000000000
implicit_acl_rule: DENY
max_number_acl_rules: 10
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)

View File

@@ -36,7 +36,7 @@ random_red_agent: False
# Default is None (null)
seed: None
# Set whether the agent will be deterministic instead of stochastic
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
@@ -55,11 +55,11 @@ hard_coded_agent_view: FULL
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# - name: ACCESS_CONTROL_LIST
# Number of episodes to run per session
num_train_episodes: 10

View File

@@ -36,7 +36,7 @@ random_red_agent: False
# Default is None (null)
seed: 67890
# Set whether the agent will be deterministic instead of stochastic
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
@@ -55,7 +55,6 @@ hard_coded_agent_view: FULL
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES

View File

@@ -38,6 +38,15 @@ load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 10
observation_space:
components:
- name: ACCESS_CONTROL_LIST
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000

View File

@@ -1,10 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
- item_type: PORTS
ports_list:
- port: '21'
- port: '80'
- item_type: SERVICES
service_list:
- name: ftp
- name: TCP
- item_type: NODE
node_id: '1'
name: node
@@ -16,8 +16,8 @@
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
- name: TCP
port: '80'
state: COMPROMISED
- item_type: NODE
node_id: '2'
@@ -30,15 +30,15 @@
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
- name: TCP
port: '80'
state: COMPROMISED
- item_type: RED_IER
id: '3'
start_step: 2
end_step: 15
load: 1000
protocol: ftp
protocol: TCP
port: CORRUPT
source: '1'
destination: '2'

View File

@@ -47,6 +47,9 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# The high value for the observation space
observation_space_high_value: 1000000000
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
implicit_acl_rule: DENY
max_number_acl_rules: 10
# Reward values
# Generic
all_ok: 0

View File

@@ -3,40 +3,41 @@
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
def test_acl_address_match_1():
"""Test that matching IP addresses produce True."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_2():
"""Test that mismatching IP addresses produce False."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
def test_acl_address_match_3():
"""Test the ANY condition for source IP addresses produce True."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_4():
"""Test the ANY condition for dest IP addresses produce True."""
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
@@ -44,14 +45,15 @@ def test_acl_address_match_4():
def test_check_acl_block_affirmative():
"""Test the block function (affirmative)."""
# Create the Access Control List
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
# Create a rule
acl_rule_permission = "ALLOW"
acl_rule_permission = RulePermissionType.ALLOW
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
acl_rule_port = "80"
acl_position_in_list = "0"
acl.add_rule(
acl_rule_permission,
@@ -59,22 +61,23 @@ def test_check_acl_block_affirmative():
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
def test_check_acl_block_negative():
"""Test the block function (negative)."""
# Create the Access Control List
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
# Create a rule
acl_rule_permission = "DENY"
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
acl_rule_port = "80"
acl_position_in_list = "0"
acl.add_rule(
acl_rule_permission,
@@ -82,6 +85,7 @@ def test_check_acl_block_negative():
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
@@ -90,11 +94,73 @@ def test_check_acl_block_negative():
def test_rule_hash():
"""Test the rule hash."""
# Create the Access Control List
acl = AccessControlList()
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert hash_value_local == hash_value_remote
def test_delete_rule():
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
# Create the Access Control List
acl = AccessControlList(RulePermissionType.ALLOW, 10)
# Create a first rule
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
acl_rule_port = "80"
acl_position_in_list = "0"
acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
# Create a second rule
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "20"
acl_rule_destination = "30"
acl_rule_protocol = "FTP"
acl_rule_port = "21"
acl_position_in_list = "2"
acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
# Create a third rule
acl_rule_permission = RulePermissionType.ALLOW
acl_rule_source = "192.168.1.3"
acl_rule_destination = "192.168.1.1"
acl_rule_protocol = "UDP"
acl_rule_port = "60"
acl_position_in_list = "4"
acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_position_in_list,
)
# Remove the second ACL rule added from the list
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
assert len(acl.acl) == 10
assert acl.acl[2] is None

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Test env creation and behaviour with different observation spaces."""
import numpy as np
import pytest
@@ -237,3 +238,140 @@ class TestLinkTrafficLevels:
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
]
],
indirect=True,
)
class TestAccessControlList:
"""Test the AccessControlList observation component (in isolation)."""
def test_obs_shape(self, temp_primaite_session):
"""Try creating env with MultiDiscrete observation space.
The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have.
Each ACL Rule in the observation space has 6 different elements:
6 * 3 = 18
"""
with temp_primaite_session as session:
env = session.env
env.update_environent_obs()
assert env.env_obs.shape == (18,)
def test_values(self, temp_primaite_session):
"""Test that traffic values are encoded correctly.
The laydown has:
* one ACL IMPLICIT DENY rule
Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at
Position 2.
"""
with temp_primaite_session as session:
env = session.env
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
def test_observation_space_with_implicit_rule(self, temp_primaite_session):
"""
Test observation space is what is expected when an agent adds ACLs during an episode.
At the start of the episode, there is a single implicit DENY rule
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0
On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1
The final observation space should be this:
[2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
"""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
training_config = env.training_config
for episode in range(0, training_config.num_train_episodes):
for step in range(0, training_config.num_train_steps):
# Do nothing action
action = 0
if step == 2:
# Action to add the first ACL rule
action = 43
elif step == 4:
# Action to add the second ACL rule
action = 96
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
obs = env.env_obs
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
def test_observation_space_with_different_positions(self, temp_primaite_session):
"""
Test observation space is what is expected when an agent adds ACLs during an episode.
At the start of the episode, there is a single implicit DENY rule
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1
On Step 4 there is a second ACL rule added at Position 0: 2,4,2,3,3,0
The final observation space should be this:
[2 , 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2]
The ACL Rule from Step 2 is added before and has a LOWER position than the ACL rule from Step 4
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
"""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
training_config = env.training_config
for episode in range(0, training_config.num_train_episodes):
for step in range(0, training_config.num_train_steps):
# Do nothing action
action = 0
if step == 2:
# Action to add the first ACL rule
action = 44
elif step == 4:
# Action to add the second ACL rule
action = 95
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
obs = env.env_obs
assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])

View File

@@ -11,30 +11,46 @@ from tests import TEST_CONFIG_ROOT
indirect=True,
)
def test_seeded_learning(temp_primaite_session):
"""Test running seeded learning produces the same output when ran twice."""
"""
Test running seeded learning produces the same output when ran twice.
.. note::
If this is failing, the hard-coded expected_mean_reward_per_episode
from a pre-trained agent will probably need to be updated. If the
env changes and those changed how this agent is trained, chances are
the mean rewards are going to be different.
Run the test, but print out the session.learn_av_reward_per_episode()
before comparing it. Then copy the printed dict and replace the
expected_mean_reward_per_episode with those values. The test should
now work. If not, then you've got a bug :).
"""
expected_mean_reward_per_episode = {
1: -90.703125,
2: -91.15234375,
3: -87.5,
4: -92.2265625,
5: -94.6875,
6: -91.19140625,
7: -88.984375,
8: -88.3203125,
9: -112.79296875,
10: -100.01953125,
1: -20.7421875,
2: -19.82421875,
3: -17.01171875,
4: -19.08203125,
5: -21.93359375,
6: -20.21484375,
7: -15.546875,
8: -12.08984375,
9: -17.59765625,
10: -14.6875,
}
with temp_primaite_session as session:
assert session._training_config.seed == 67890, (
"Expected output is based upon a agent that was trained with " "seed 67890"
)
assert (
session._training_config.seed == 67890
), "Expected output is based upon a agent that was trained with seed 67890"
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
print(actual_mean_reward_per_episode, "THISt")
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.")
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],

View File

@@ -43,6 +43,9 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
return copy_path
@pytest.mark.xfail(
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
)
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
expected_learn_mean_reward_per_episode = {

View File

@@ -3,6 +3,7 @@ import time
import pytest
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import HardwareState
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
@@ -19,16 +20,17 @@ def run_generic_set_actions(env: Primaite):
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
# print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1]
# [1, 1, 2, 1, 1, 1, 1(position)]
# Creates an ACL rule
# Allows traffic from server_1 to node_1 on port FTP
action = 7
action = 56
elif step == 7:
# [1, 1, 2, 0] Node Action
# Sets Node 1 Hardware State to OFF
# Does not resolve any service
action = 16
action = 128
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
@@ -57,6 +59,10 @@ def run_generic_set_actions(env: Primaite):
)
def test_single_action_space_is_valid(temp_primaite_session):
"""Test single action space is valid."""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
@@ -73,7 +79,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
if len(dict_item) == 4:
contains_node_actions = True
# Link action detected
elif len(dict_item) == 6:
elif len(dict_item) == 7:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
@@ -94,6 +100,10 @@ def test_single_action_space_is_valid(temp_primaite_session):
)
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
# TODO: Refactor this at some point to build a custom ACL Hardcoded
# Agent and then patch the AgentIdentifier Enum class so that it
# has ACL_AGENT. This then allows us to set the agent identified in
# the main config and is a bit cleaner.
with temp_primaite_session as session:
env = session.env
# Run environment with specified fixed blue agent actions only
@@ -105,11 +115,15 @@ def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
access_control_list = env.acl
# Use the Access Control List object acl object attribute to get dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
acl_rules_list = access_control_list.acl
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# This 1 rule added to the implicit deny means there should be 2 rules in total.
rules_count = 0
for rule in acl_rules_list:
if isinstance(rule, ACLRule):
rules_count += 1
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1
assert rules_count == 2