901 - changed ACL instantiation and changed acl t private _acl (list not dict) attribute, added laydown_ACL.yaml for testing, fixed encoding of acl rules to integers for obs space, added ACL position to node action space and added generic test where agents adds two ACL rules.

This commit is contained in:
SunilSamra
2023-06-20 11:47:20 +01:00
parent c6a947fbaf
commit df42a791c9
9 changed files with 305 additions and 107 deletions

View File

@@ -11,21 +11,43 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class AccessControlList:
"""Access Control List class."""
def __init__(self, implicit_permission, max_acl_rules):
def __init__(self, apply_implicit_rule, implicit_permission, max_acl_rules):
"""Init."""
# Bool option in main_config to decide to use implicit rule or not
self.apply_implicit_rule: bool = apply_implicit_rule
# Implicit ALLOW or DENY firewall spec
# Last rule in the ACL list
self.acl_implicit_rule = implicit_permission
# Create implicit rule based on input
if self.acl_implicit_rule == "DENY":
implicit_rule = ACLRule("DENY", "ANY", "ANY", "ANY", "ANY")
else:
implicit_rule = ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY")
self.acl_implicit_permission = implicit_permission
# Maximum number of ACL Rules in ACL
self.max_acl_rules: int = max_acl_rules
# A list of ACL Rules
self.acl: List[ACLRule] = [implicit_rule]
self._acl: List[ACLRule] = []
# Implicit rule
@property
def acl_implicit_rule(self):
"""ACL implicit rule class attribute with added logic to change it depending on option in main_config."""
# Create implicit rule based on input
if self.apply_implicit_rule:
if self.acl_implicit_permission == "DENY":
return ACLRule("DENY", "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == "ALLOW":
return ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY")
else:
return None
else:
return None
@property
def acl(self):
"""Public access method for private _acl.
Adds implicit rule to end of acl list and
Pads out rest of list (if empty) with -1.
"""
if self.acl_implicit_rule is not None:
acl_list = self._acl + [self.acl_implicit_rule]
return acl_list + [-1] * (self.max_acl_rules - len(acl_list))
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
@@ -85,7 +107,9 @@ class AccessControlList:
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port, _position):
def add_rule(
self, _permission, _source_ip, _dest_ip, _protocol, _port, _position=None
):
"""
Adds a new rule.
@@ -99,18 +123,22 @@ class AccessControlList:
"""
position_index = int(_position)
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
if len(self.acl) < self.max_acl_rules:
if len(self.acl) > position_index > -1:
try:
self.acl.insert(position_index, new_rule)
except Exception:
print(len(self._acl))
if len(self._acl) + 1 < self.max_acl_rules:
if _position is not None:
if self.max_acl_rules - 1 > position_index > -1:
try:
self._acl.insert(position_index, new_rule)
except Exception:
_LOGGER.info(
f"New Rule could NOT be added to list at position {position_index}."
)
else:
_LOGGER.info(
f"New Rule could NOT be added to list at position {position_index}."
f"Position {position_index} is an invalid index for list/overwrites implicit firewall rule"
)
else:
_LOGGER.info(
f"Position {position_index} is an invalid index for list and/or overwrites implicit firewall rule"
)
self.acl.append(new_rule)
else:
_LOGGER.info(
f"The ACL list is FULL."

View File

@@ -1,5 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
@@ -10,7 +9,6 @@ from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType
_LOGGER = getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, format="%(message)s")
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
import numpy as np
from gym import spaces
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import (
FileSystemState,
HardwareState,
@@ -22,7 +23,6 @@ from primaite.nodes.service_node import ServiceNode
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
@@ -346,16 +346,19 @@ class AccessControlList(AbstractObservationComponent):
# 1. Define the shape of your observation space component
acl_shape = [
len(RulePermissionType),
len(env.nodes),
len(env.nodes),
len(env.nodes) + 1,
len(env.nodes) + 1,
len(env.services_list),
len(env.ports_list),
env.max_number_acl_rules,
]
len(acl_shape)
# shape = acl_shape
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
print("obs space:", self.space)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
@@ -365,67 +368,85 @@ class AccessControlList(AbstractObservationComponent):
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
for acl_rule in self.env.acl.acl:
permission = acl_rule.permission
source_ip = acl_rule.source_ip
dest_ip = acl_rule.dest_ip
protocol = acl_rule.protocol
port = acl_rule.port
position = self.env.acl.acl.index(acl_rule)
if permission == "DENY":
permission_int = 0
else:
permission_int = 1
if source_ip == "ANY":
source_ip_int = 0
else:
source_ip_int = self.obtain_node_id_using_ip(source_ip)
if dest_ip == "ANY":
dest_ip_int = 0
else:
dest_ip_int = self.obtain_node_id_using_ip(dest_ip)
if protocol == "ANY":
protocol_int = 0
else:
try:
protocol_int = Protocol[protocol].value
except AttributeError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = -1
if port == "ANY":
port_int = 0
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port)
for index in range(len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission
source_ip = acl_rule.source_ip
dest_ip = acl_rule.dest_ip
protocol = acl_rule.protocol
port = acl_rule.port
position = index
source_ip_int = -1
dest_ip_int = -1
if permission == "DENY":
permission_int = 0
else:
_LOGGER.info(f"Port {port} could not be found.")
permission_int = 1
if source_ip == "ANY":
source_ip_int = 0
else:
nodes = list(self.env.nodes.values())
for node in nodes:
# print(node.ip_address, source_ip, node.ip_address == source_ip)
if (
isinstance(node, ServiceNode)
or isinstance(node, ActiveNode)
) and node.ip_address == source_ip:
source_ip_int = node.node_id
break
if dest_ip == "ANY":
dest_ip_int = 0
else:
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode)
or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip:
dest_ip_int = node.node_id
if protocol == "ANY":
protocol_int = 0
else:
try:
protocol_int = Protocol[protocol].value
except AttributeError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = -1
if port == "ANY":
port_int = 0
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port)
else:
_LOGGER.info(f"Port {port} could not be found.")
print(permission_int, source_ip, dest_ip, protocol_int, port_int, position)
obs.extend(
[
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
)
# Either do the multiply on the obs space
# Change the obs to
if source_ip_int != -1 and dest_ip_int != -1:
items_to_add = [
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
position = position * 6
for item in items_to_add:
obs.insert(position, int(item))
position += 1
else:
items_to_add = [-1, -1, -1, -1, -1, index]
position = index * 6
for item in items_to_add:
obs.insert(position, int(item))
position += 1
self.current_observation[:] = obs
def obtain_node_id_using_ip(self, ip_address):
"""Uses IP address of Nodes to find the ID.
Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space
"""
print(type(self.env.nodes))
for key, node in self.env.nodes.items():
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
if node.ip_address == ip_address:
return key
_LOGGER.info(f"Node ID was not found from IP Address {ip_address}")
return -1
self.current_observation = obs
print("current observation space:", self.current_observation)
class ObservationsHandler:

View File

@@ -45,7 +45,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
# _LOGGER.setLevel(logging.INFO)
class Primaite(Env):
@@ -119,6 +119,7 @@ class Primaite(Env):
# Create the Access Control List
self.acl = AccessControlList(
self.training_config.apply_implicit_rule,
self.training_config.implicit_acl_rule,
self.training_config.max_number_acl_rules,
)
@@ -546,6 +547,7 @@ class Primaite(Env):
action_destination_ip = readable_action[3]
action_protocol = readable_action[4]
action_port = readable_action[5]
acl_rule_position = readable_action[6]
if action_decision == 0:
# It's decided to do nothing
@@ -595,6 +597,7 @@ class Primaite(Env):
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
acl_rule_position,
)
elif action_decision == 2:
# Remove the rule
@@ -1172,13 +1175,9 @@ class Primaite(Env):
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
action_key = 1
print(
"what is this primaite_env.py 1177",
self.training_config.max_number_acl_rules - 1,
)
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3):
# 2 possible action permissions 0 = DENY, 1 = CREATE
@@ -1188,9 +1187,7 @@ class Primaite(Env):
for dest_ip in range(self.num_nodes + 1):
for protocol in range(self.num_services + 1):
for port in range(self.num_ports + 1):
for position in range(
self.training_config.max_number_acl_rules - 1
):
for position in range(self.max_number_acl_rules - 1):
action = [
action_decision,
action_permission,

View File

@@ -0,0 +1,86 @@
- item_type: PORTS
ports_list:
- port: '80'
- port: '21'
- item_type: SERVICES
service_list:
- name: TCP
- name: FTP
########################################
# Nodes
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: COMPROMISED
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: FTP
port: '21'
state: GOOD
- item_type: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: FTP
port: '21'
state: OVERWHELMED
- item_type: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- item_type: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- item_type: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'
#########################################
# IERS
- item_type: GREEN_IER
id: '5'
start_step: 0
end_step: 5
load: 999
protocol: TCP
port: '80'
source: '1'
destination: '2'
mission_criticality: 5
#########################################
# ACL Rules

View File

@@ -24,6 +24,19 @@ load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
apply_implicit_rule: True
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 10
observation_space:
components:
- name: ACCESS_CONTROL_LIST
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000

View File

@@ -7,7 +7,7 @@ from primaite.acl.acl_rule import ACLRule
def test_acl_address_match_1():
"""Test that matching IP addresses produce True."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
@@ -16,7 +16,7 @@ def test_acl_address_match_1():
def test_acl_address_match_2():
"""Test that mismatching IP addresses produce False."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
@@ -25,7 +25,7 @@ def test_acl_address_match_2():
def test_acl_address_match_3():
"""Test the ANY condition for source IP addresses produce True."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
@@ -34,7 +34,7 @@ def test_acl_address_match_3():
def test_acl_address_match_4():
"""Test the ANY condition for dest IP addresses produce True."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
@@ -44,7 +44,7 @@ def test_acl_address_match_4():
def test_check_acl_block_affirmative():
"""Test the block function (affirmative)."""
# Create the Access Control List
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
# Create a rule
acl_rule_permission = "ALLOW"
@@ -68,7 +68,7 @@ def test_check_acl_block_affirmative():
def test_check_acl_block_negative():
"""Test the block function (negative)."""
# Create the Access Control List
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
# Create a rule
acl_rule_permission = "DENY"
@@ -93,7 +93,7 @@ def test_check_acl_block_negative():
def test_rule_hash():
"""Test the rule hash."""
# Create the Access Control List
acl = AccessControlList("DENY", 10)
acl = AccessControlList(True, "DENY", 10)
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)

View File

@@ -1,4 +1,7 @@
"""Test env creation and behaviour with different observation spaces."""
import time
import numpy as np
import pytest
@@ -12,6 +15,46 @@ from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def run_generic_set_actions(env: Primaite):
"""Run against a generic agent with specified blue agent actions."""
# Reset the environment at the start of the episode
# env.reset()
training_config = env.training_config
for episode in range(0, training_config.num_episodes):
for step in range(0, training_config.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
print("\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1, 2] ACL Action
# Creates an ACL rule
# Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2
action = 291
elif step == 7:
# [1, 1, 3, 1, 2, 2, 1] ACL Action
# Creates an ACL rule
# Allows traffic from PC1 to SWITCH 1 on port UDP at position 1
action = 425
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Update observations space and return
env.update_environent_obs()
# Break if done is True
if done:
break
# Introduce a delay between steps
time.sleep(training_config.time_delay / 1000)
# Reset the environment at the end of the episode
# env.reset()
# env.close()
@pytest.fixture
def env(request):
"""Build Primaite environment for integration tests of observation space."""
@@ -131,11 +174,11 @@ class TestNodeLinkTable:
assert np.array_equal(
obs,
[
[1, 1, 3, 1, 1, 1],
[2, 1, 1, 1, 1, 4],
[3, 1, 1, 1, 0, 0],
[4, 0, 0, 0, 999, 0],
[5, 0, 0, 0, 999, 0],
[1, 1, 3, 1, 1, 1, 0],
[2, 1, 1, 1, 1, 4, 1],
[3, 1, 1, 1, 0, 0, 2],
[4, 0, 0, 0, 999, 0, 3],
[5, 0, 0, 0, 999, 0, 4],
],
)
@@ -260,4 +303,16 @@ class TestAccessControlList:
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
print(obs)
assert np.array_equal(obs, [6, 0, 6, 0])
assert np.array_equal(obs, [])
def test_observation_space(self):
"""Test observation space is what is expected when an agent adds ACLs during an episode."""
# Used to use env from test fixture but AtrributeError function object has no 'training_config'
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
)
run_generic_set_actions(env)
# print("observation space",env.observation_space)

View File

@@ -66,7 +66,7 @@ def test_single_action_space_is_valid():
if len(dict_item) == 4:
contains_node_actions = True
# Link action detected
elif len(dict_item) == 6:
elif len(dict_item) == 7:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
@@ -92,7 +92,7 @@ def test_agent_is_executing_actions_from_both_spaces():
access_control_list = env.acl
# Use the Access Control List object acl object attribute to get dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
acl_rules_list = access_control_list.acl
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule