From df42a791c9bd6729c8ab36d3e92a62398bd8b8c5 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 20 Jun 2023 11:47:20 +0100 Subject: [PATCH] 901 - changed ACL instantiation and changed acl t private _acl (list not dict) attribute, added laydown_ACL.yaml for testing, fixed encoding of acl rules to integers for obs space, added ACL position to node action space and added generic test where agents adds two ACL rules. --- src/primaite/acl/access_control_list.py | 66 +++++--- src/primaite/config/training_config.py | 2 - src/primaite/environment/observations.py | 145 ++++++++++-------- src/primaite/environment/primaite_env.py | 15 +- tests/config/obs_tests/laydown_ACL.yaml | 86 +++++++++++ ..._space_fixed_blue_actions_main_config.yaml | 13 ++ tests/test_acl.py | 14 +- tests/test_observation_space.py | 67 +++++++- tests/test_single_action_space.py | 4 +- 9 files changed, 305 insertions(+), 107 deletions(-) create mode 100644 tests/config/obs_tests/laydown_ACL.yaml diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index d75b9756..219ba002 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -11,21 +11,43 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__) class AccessControlList: """Access Control List class.""" - def __init__(self, implicit_permission, max_acl_rules): + def __init__(self, apply_implicit_rule, implicit_permission, max_acl_rules): """Init.""" + # Bool option in main_config to decide to use implicit rule or not + self.apply_implicit_rule: bool = apply_implicit_rule # Implicit ALLOW or DENY firewall spec # Last rule in the ACL list - self.acl_implicit_rule = implicit_permission - # Create implicit rule based on input - if self.acl_implicit_rule == "DENY": - implicit_rule = ACLRule("DENY", "ANY", "ANY", "ANY", "ANY") - else: - implicit_rule = ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY") - + self.acl_implicit_permission = implicit_permission # Maximum number of ACL Rules in ACL self.max_acl_rules: int = max_acl_rules # A list of ACL Rules - self.acl: List[ACLRule] = [implicit_rule] + self._acl: List[ACLRule] = [] + # Implicit rule + + @property + def acl_implicit_rule(self): + """ACL implicit rule class attribute with added logic to change it depending on option in main_config.""" + # Create implicit rule based on input + if self.apply_implicit_rule: + if self.acl_implicit_permission == "DENY": + return ACLRule("DENY", "ANY", "ANY", "ANY", "ANY") + elif self.acl_implicit_permission == "ALLOW": + return ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY") + else: + return None + else: + return None + + @property + def acl(self): + """Public access method for private _acl. + + Adds implicit rule to end of acl list and + Pads out rest of list (if empty) with -1. + """ + if self.acl_implicit_rule is not None: + acl_list = self._acl + [self.acl_implicit_rule] + return acl_list + [-1] * (self.max_acl_rules - len(acl_list)) def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): """ @@ -85,7 +107,9 @@ class AccessControlList: # If there has been no rule to allow the IER through, it will return a blocked signal by default return True - def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port, _position): + def add_rule( + self, _permission, _source_ip, _dest_ip, _protocol, _port, _position=None + ): """ Adds a new rule. @@ -99,18 +123,22 @@ class AccessControlList: """ position_index = int(_position) new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - if len(self.acl) < self.max_acl_rules: - if len(self.acl) > position_index > -1: - try: - self.acl.insert(position_index, new_rule) - except Exception: + print(len(self._acl)) + if len(self._acl) + 1 < self.max_acl_rules: + if _position is not None: + if self.max_acl_rules - 1 > position_index > -1: + try: + self._acl.insert(position_index, new_rule) + except Exception: + _LOGGER.info( + f"New Rule could NOT be added to list at position {position_index}." + ) + else: _LOGGER.info( - f"New Rule could NOT be added to list at position {position_index}." + f"Position {position_index} is an invalid index for list/overwrites implicit firewall rule" ) else: - _LOGGER.info( - f"Position {position_index} is an invalid index for list and/or overwrites implicit firewall rule" - ) + self.acl.append(new_rule) else: _LOGGER.info( f"The ACL list is FULL." diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 14102432..9a21d087 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,5 +1,4 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -import logging from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Final, Optional, Union @@ -10,7 +9,6 @@ from primaite import USERS_CONFIG_DIR, getLogger from primaite.common.enums import ActionType _LOGGER = getLogger(__name__) -logging.basicConfig(level=logging.DEBUG, format="%(message)s") _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index fe43c9e3..eb7ad2bf 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union import numpy as np from gym import spaces +from primaite.acl.acl_rule import ACLRule from primaite.common.enums import ( FileSystemState, HardwareState, @@ -22,7 +23,6 @@ from primaite.nodes.service_node import ServiceNode if TYPE_CHECKING: from primaite.environment.primaite_env import Primaite - _LOGGER = logging.getLogger(__name__) @@ -346,16 +346,19 @@ class AccessControlList(AbstractObservationComponent): # 1. Define the shape of your observation space component acl_shape = [ len(RulePermissionType), - len(env.nodes), - len(env.nodes), + len(env.nodes) + 1, + len(env.nodes) + 1, len(env.services_list), len(env.ports_list), + env.max_number_acl_rules, ] + len(acl_shape) + # shape = acl_shape shape = acl_shape * self.env.max_number_acl_rules # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) - + print("obs space:", self.space) # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) @@ -365,67 +368,85 @@ class AccessControlList(AbstractObservationComponent): The structure of the observation space is described in :class:`.AccessControlList` """ obs = [] - for acl_rule in self.env.acl.acl: - permission = acl_rule.permission - source_ip = acl_rule.source_ip - dest_ip = acl_rule.dest_ip - protocol = acl_rule.protocol - port = acl_rule.port - position = self.env.acl.acl.index(acl_rule) - if permission == "DENY": - permission_int = 0 - else: - permission_int = 1 - if source_ip == "ANY": - source_ip_int = 0 - else: - source_ip_int = self.obtain_node_id_using_ip(source_ip) - if dest_ip == "ANY": - dest_ip_int = 0 - else: - dest_ip_int = self.obtain_node_id_using_ip(dest_ip) - if protocol == "ANY": - protocol_int = 0 - else: - try: - protocol_int = Protocol[protocol].value - except AttributeError: - _LOGGER.info(f"Service {protocol} could not be found") - protocol_int = -1 - if port == "ANY": - port_int = 0 - else: - if port in self.env.ports_list: - port_int = self.env.ports_list.index(port) + + for index in range(len(self.env.acl.acl)): + acl_rule = self.env.acl.acl[index] + if isinstance(acl_rule, ACLRule): + permission = acl_rule.permission + source_ip = acl_rule.source_ip + dest_ip = acl_rule.dest_ip + protocol = acl_rule.protocol + port = acl_rule.port + position = index + + source_ip_int = -1 + dest_ip_int = -1 + if permission == "DENY": + permission_int = 0 else: - _LOGGER.info(f"Port {port} could not be found.") + permission_int = 1 + if source_ip == "ANY": + source_ip_int = 0 + else: + nodes = list(self.env.nodes.values()) + for node in nodes: + # print(node.ip_address, source_ip, node.ip_address == source_ip) + if ( + isinstance(node, ServiceNode) + or isinstance(node, ActiveNode) + ) and node.ip_address == source_ip: + source_ip_int = node.node_id + break + if dest_ip == "ANY": + dest_ip_int = 0 + else: + nodes = list(self.env.nodes.values()) + for node in nodes: + if ( + isinstance(node, ServiceNode) + or isinstance(node, ActiveNode) + ) and node.ip_address == dest_ip: + dest_ip_int = node.node_id + if protocol == "ANY": + protocol_int = 0 + else: + try: + protocol_int = Protocol[protocol].value + except AttributeError: + _LOGGER.info(f"Service {protocol} could not be found") + protocol_int = -1 + if port == "ANY": + port_int = 0 + else: + if port in self.env.ports_list: + port_int = self.env.ports_list.index(port) + else: + _LOGGER.info(f"Port {port} could not be found.") - print(permission_int, source_ip, dest_ip, protocol_int, port_int, position) - obs.extend( - [ - permission_int, - source_ip_int, - dest_ip_int, - protocol_int, - port_int, - position, - ] - ) + # Either do the multiply on the obs space + # Change the obs to + if source_ip_int != -1 and dest_ip_int != -1: + items_to_add = [ + permission_int, + source_ip_int, + dest_ip_int, + protocol_int, + port_int, + position, + ] + position = position * 6 + for item in items_to_add: + obs.insert(position, int(item)) + position += 1 + else: + items_to_add = [-1, -1, -1, -1, -1, index] + position = index * 6 + for item in items_to_add: + obs.insert(position, int(item)) + position += 1 - self.current_observation[:] = obs - - def obtain_node_id_using_ip(self, ip_address): - """Uses IP address of Nodes to find the ID. - - Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space - """ - print(type(self.env.nodes)) - for key, node in self.env.nodes.items(): - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - if node.ip_address == ip_address: - return key - _LOGGER.info(f"Node ID was not found from IP Address {ip_address}") - return -1 + self.current_observation = obs + print("current observation space:", self.current_observation) class ObservationsHandler: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 783b4267..c0ee04f5 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -45,7 +45,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) -_LOGGER.setLevel(logging.INFO) +# _LOGGER.setLevel(logging.INFO) class Primaite(Env): @@ -119,6 +119,7 @@ class Primaite(Env): # Create the Access Control List self.acl = AccessControlList( + self.training_config.apply_implicit_rule, self.training_config.implicit_acl_rule, self.training_config.max_number_acl_rules, ) @@ -546,6 +547,7 @@ class Primaite(Env): action_destination_ip = readable_action[3] action_protocol = readable_action[4] action_port = readable_action[5] + acl_rule_position = readable_action[6] if action_decision == 0: # It's decided to do nothing @@ -595,6 +597,7 @@ class Primaite(Env): acl_rule_destination, acl_rule_protocol, acl_rule_port, + acl_rule_position, ) elif action_decision == 2: # Remove the rule @@ -1172,13 +1175,9 @@ class Primaite(Env): # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) # [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list) # reserve 0 action to be a nothing action - actions = {0: [0, 0, 0, 0, 0, 0]} + actions = {0: [0, 0, 0, 0, 0, 0, 0]} action_key = 1 - print( - "what is this primaite_env.py 1177", - self.training_config.max_number_acl_rules - 1, - ) # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE for action_decision in range(3): # 2 possible action permissions 0 = DENY, 1 = CREATE @@ -1188,9 +1187,7 @@ class Primaite(Env): for dest_ip in range(self.num_nodes + 1): for protocol in range(self.num_services + 1): for port in range(self.num_ports + 1): - for position in range( - self.training_config.max_number_acl_rules - 1 - ): + for position in range(self.max_number_acl_rules - 1): action = [ action_decision, action_permission, diff --git a/tests/config/obs_tests/laydown_ACL.yaml b/tests/config/obs_tests/laydown_ACL.yaml new file mode 100644 index 00000000..cffd8b1c --- /dev/null +++ b/tests/config/obs_tests/laydown_ACL.yaml @@ -0,0 +1,86 @@ +- item_type: PORTS + ports_list: + - port: '80' + - port: '21' +- item_type: SERVICES + service_list: + - name: TCP + - name: FTP + +######################################## +# Nodes +- item_type: NODE + node_id: '1' + name: PC1 + node_class: SERVICE + node_type: COMPUTER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.1 + software_state: COMPROMISED + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: FTP + port: '21' + state: GOOD +- item_type: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: FTP + port: '21' + state: OVERWHELMED +- item_type: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- item_type: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- item_type: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' + +######################################### +# IERS +- item_type: GREEN_IER + id: '5' + start_step: 0 + end_step: 5 + load: 999 + protocol: TCP + port: '80' + source: '1' + destination: '2' + mission_criticality: 5 + +######################################### +# ACL Rules diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 5c5db582..e2718c53 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -24,6 +24,19 @@ load_agent: False # File path and file name of agent if you're loading one in agent_load_file: C:\[Path]\[agent_saved_filename.zip] + + +# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE) +apply_implicit_rule: True +# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY) +implicit_acl_rule: DENY +# Total number of ACL rules allowed in the environment +max_number_acl_rules: 10 + +observation_space: + components: + - name: ACCESS_CONTROL_LIST + # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 diff --git a/tests/test_acl.py b/tests/test_acl.py index e99f5ee0..f790a5cf 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -7,7 +7,7 @@ from primaite.acl.acl_rule import ACLRule def test_acl_address_match_1(): """Test that matching IP addresses produce True.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80") @@ -16,7 +16,7 @@ def test_acl_address_match_1(): def test_acl_address_match_2(): """Test that mismatching IP addresses produce False.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80") @@ -25,7 +25,7 @@ def test_acl_address_match_2(): def test_acl_address_match_3(): """Test the ANY condition for source IP addresses produce True.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80") @@ -34,7 +34,7 @@ def test_acl_address_match_3(): def test_acl_address_match_4(): """Test the ANY condition for dest IP addresses produce True.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80") @@ -44,7 +44,7 @@ def test_acl_address_match_4(): def test_check_acl_block_affirmative(): """Test the block function (affirmative).""" # Create the Access Control List - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) # Create a rule acl_rule_permission = "ALLOW" @@ -68,7 +68,7 @@ def test_check_acl_block_affirmative(): def test_check_acl_block_negative(): """Test the block function (negative).""" # Create the Access Control List - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) # Create a rule acl_rule_permission = "DENY" @@ -93,7 +93,7 @@ def test_check_acl_block_negative(): def test_rule_hash(): """Test the rule hash.""" # Create the Access Control List - acl = AccessControlList("DENY", 10) + acl = AccessControlList(True, "DENY", 10) rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") hash_value_local = hash(rule) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 4e8df7e1..5408bee6 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,4 +1,7 @@ """Test env creation and behaviour with different observation spaces.""" + +import time + import numpy as np import pytest @@ -12,6 +15,46 @@ from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config +def run_generic_set_actions(env: Primaite): + """Run against a generic agent with specified blue agent actions.""" + # Reset the environment at the start of the episode + # env.reset() + training_config = env.training_config + for episode in range(0, training_config.num_episodes): + for step in range(0, training_config.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = 0 + print("\nStep:", step) + if step == 5: + # [1, 1, 2, 1, 1, 1, 2] ACL Action + # Creates an ACL rule + # Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2 + action = 291 + elif step == 7: + # [1, 1, 3, 1, 2, 2, 1] ACL Action + # Creates an ACL rule + # Allows traffic from PC1 to SWITCH 1 on port UDP at position 1 + action = 425 + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + # Update observations space and return + env.update_environent_obs() + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(training_config.time_delay / 1000) + + # Reset the environment at the end of the episode + # env.reset() + + # env.close() + + @pytest.fixture def env(request): """Build Primaite environment for integration tests of observation space.""" @@ -131,11 +174,11 @@ class TestNodeLinkTable: assert np.array_equal( obs, [ - [1, 1, 3, 1, 1, 1], - [2, 1, 1, 1, 1, 4], - [3, 1, 1, 1, 0, 0], - [4, 0, 0, 0, 999, 0], - [5, 0, 0, 0, 999, 0], + [1, 1, 3, 1, 1, 1, 0], + [2, 1, 1, 1, 1, 4, 1], + [3, 1, 1, 1, 0, 0, 2], + [4, 0, 0, 0, 999, 0, 3], + [5, 0, 0, 0, 999, 0, 4], ], ) @@ -260,4 +303,16 @@ class TestAccessControlList: # therefore the first and third elements should be 6 and all others 0 # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) print(obs) - assert np.array_equal(obs, [6, 0, 6, 0]) + assert np.array_equal(obs, []) + + def test_observation_space(self): + """Test observation space is what is expected when an agent adds ACLs during an episode.""" + # Used to use env from test fixture but AtrributeError function object has no 'training_config' + env = _get_primaite_env_from_config( + training_config_path=TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml", + ) + run_generic_set_actions(env) + + # print("observation space",env.observation_space) diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 4d6136a9..78764976 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -66,7 +66,7 @@ def test_single_action_space_is_valid(): if len(dict_item) == 4: contains_node_actions = True # Link action detected - elif len(dict_item) == 6: + elif len(dict_item) == 7: contains_acl_actions = True # If both are there then the ANY action type is working if contains_node_actions and contains_acl_actions: @@ -92,7 +92,7 @@ def test_agent_is_executing_actions_from_both_spaces(): access_control_list = env.acl # Use the Access Control List object acl object attribute to get dictionary # Use dictionary.values() to get total list of all items in the dictionary - acl_rules_list = access_control_list.acl.values() + acl_rules_list = access_control_list.acl # Length of this list tells you how many items are in the dictionary # This number is the frequency of Access Control Rules in the environment # In the scenario, we specified that the agent should create only 1 acl rule