From 3f440c0a281917f0c9c9469f2651534e84313ed8 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 5 Jul 2023 09:08:03 +0100 Subject: [PATCH] 901 - updated observations.py to change and add new mapping of ACL rules to represent no rule present in list --- src/primaite/acl/access_control_list.py | 27 ++--- src/primaite/common/enums.py | 7 +- .../training/training_config_main.yaml | 6 +- src/primaite/environment/observations.py | 98 +++++++++---------- src/primaite/environment/primaite_env.py | 3 +- ..._space_fixed_blue_actions_main_config.yaml | 4 +- tests/test_observation_space.py | 26 +++-- 7 files changed, 95 insertions(+), 76 deletions(-) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 9cc1225a..9e51e066 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -39,7 +39,9 @@ class AccessControlList: """ if self.acl_implicit_rule is not None: acl_list = self._acl + [self.acl_implicit_rule] - return acl_list + [-1] * (self.max_acl_rules - len(acl_list)) + else: + acl_list = self._acl + return acl_list + [None] * (self.max_acl_rules - len(acl_list)) def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): """ @@ -86,15 +88,18 @@ class AccessControlList: Indicates block if all conditions are satisfied. """ for rule in self.acl: - if self.check_address_match(rule, _source_ip_address, _dest_ip_address): - if ( - rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" - ) and (str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"): - # There's a matching rule. Get the permission - if rule.get_permission() == "DENY": - return True - elif rule.get_permission() == "ALLOW": - return False + if isinstance(rule, ACLRule): + if self.check_address_match(rule, _source_ip_address, _dest_ip_address): + if ( + rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" + ) and ( + str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" + ): + # There's a matching rule. Get the permission + if rule.get_permission() == "DENY": + return True + elif rule.get_permission() == "ALLOW": + return False # If there has been no rule to allow the IER through, it will return a blocked signal by default return True @@ -115,7 +120,6 @@ class AccessControlList: """ position_index = int(_position) new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - print(len(self._acl)) if len(self._acl) + 1 < self.max_acl_rules: if _position is not None: if self.max_acl_rules - 1 > position_index > -1: @@ -136,6 +140,7 @@ class AccessControlList: f"The ACL list is FULL." f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}." ) + # print("length of this list", len(self._acl)) def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): """ diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 6a0c8f29..ad6c84a1 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -131,7 +131,8 @@ class LinkStatus(Enum): class RulePermissionType(Enum): - """Implicit firewall rule.""" + """Any firewall rule type.""" - DENY = 0 - ALLOW = 1 + NA = 0 + DENY = 1 + ALLOW = 2 diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..233c299e 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -5,14 +5,14 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +agent_identifier: STABLE_BASELINES3_PPO # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions -action_type: NODE +action_type: ACL # Number of episodes to run per session -num_episodes: 10 +num_episodes: 1000 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 2aacda8f..d254598b 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -10,7 +10,6 @@ from primaite.acl.acl_rule import ACLRule from primaite.common.enums import ( FileSystemState, HardwareState, - Protocol, RulePermissionType, SoftwareState, ) @@ -330,13 +329,14 @@ class AccessControlList(AbstractObservationComponent): ] """ + 0, # Terms (for ACL observation space): - # [0, 1] - Permission (0 = DENY, 1 = ALLOW) - # [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) - # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) - # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) - # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) - # [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list) + # [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW) + # [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) + # [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) + # [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol) + # [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port) + # [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list) _DATA_TYPE: type = np.int64 @@ -346,18 +346,17 @@ class AccessControlList(AbstractObservationComponent): # 1. Define the shape of your observation space component acl_shape = [ len(RulePermissionType), - len(env.nodes) + 1, - len(env.nodes) + 1, - len(env.services_list), - len(env.ports_list), - env.max_number_acl_rules, + len(env.nodes) + 2, + len(env.nodes) + 2, + len(env.services_list) + 1, + len(env.ports_list) + 1, + env.max_number_acl_rules + 1, ] - # shape = acl_shape shape = acl_shape * self.env.max_number_acl_rules # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) - print("obs space:", self.space) + # print("obs space:", self.space) # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) @@ -367,8 +366,8 @@ class AccessControlList(AbstractObservationComponent): The structure of the observation space is described in :class:`.AccessControlList` """ obs = [] - - for index in range(len(self.env.acl.acl)): + # print("starting len", len(self.env.acl.acl)) + for index in range(0, len(self.env.acl.acl)): acl_rule = self.env.acl.acl[index] if isinstance(acl_rule, ACLRule): permission = acl_rule.permission @@ -378,26 +377,25 @@ class AccessControlList(AbstractObservationComponent): port = acl_rule.port position = index - source_ip_int = -1 - dest_ip_int = -1 + source_ip_int = None + dest_ip_int = None if permission == "DENY": - permission_int = 0 - else: permission_int = 1 + else: + permission_int = 2 if source_ip == "ANY": - source_ip_int = 0 + source_ip_int = 1 else: nodes = list(self.env.nodes.values()) for node in nodes: - # print(node.ip_address, source_ip, node.ip_address == source_ip) if ( isinstance(node, ServiceNode) or isinstance(node, ActiveNode) ) and node.ip_address == source_ip: - source_ip_int = node.node_id + source_ip_int = int(node.node_id) + 1 break if dest_ip == "ANY": - dest_ip_int = 0 + dest_ip_int = 1 else: nodes = list(self.env.nodes.values()) for node in nodes: @@ -405,46 +403,46 @@ class AccessControlList(AbstractObservationComponent): isinstance(node, ServiceNode) or isinstance(node, ActiveNode) ) and node.ip_address == dest_ip: - dest_ip_int = node.node_id + dest_ip_int = int(node.node_id) + 1 if protocol == "ANY": - protocol_int = 0 + protocol_int = 1 else: try: - protocol_int = Protocol[protocol].value + protocol_int = self.env.services_list.index(protocol) + 2 except AttributeError: _LOGGER.info(f"Service {protocol} could not be found") - protocol_int = -1 + protocol_int = None if port == "ANY": - port_int = 0 + port_int = 1 else: if port in self.env.ports_list: - port_int = self.env.ports_list.index(port) + port_int = self.env.ports_list.index(port) + 2 else: _LOGGER.info(f"Port {port} could not be found.") # Either do the multiply on the obs space # Change the obs to - if source_ip_int != -1 and dest_ip_int != -1: - items_to_add = [ - permission_int, - source_ip_int, - dest_ip_int, - protocol_int, - port_int, - position, - ] - position = position * 6 - for item in items_to_add: - obs.insert(position, int(item)) - position += 1 - else: - items_to_add = [-1, -1, -1, -1, -1, index] - position = index * 6 - for item in items_to_add: - obs.insert(position, int(item)) - position += 1 + items_to_add = [ + permission_int, + source_ip_int, + dest_ip_int, + protocol_int, + port_int, + position, + ] + position = position * 6 + for item in items_to_add: + # print("position", position, "\nitem", int(item)) + obs.insert(position, int(item)) + position += 1 + else: + starting_position = index * 6 + for placeholder in range(6): + obs.insert(starting_position, 0) + starting_position += 1 - self.current_observation = obs + # print("current obs", obs, "\n" ,len(obs)) + self.current_observation[:] = obs class ObservationsHandler: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index f6a3d48e..3386a96c 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1204,7 +1204,8 @@ class Primaite(Env): # [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list) # reserve 0 action to be a nothing action actions = {0: [0, 0, 0, 0, 0, 0, 0]} - + # [1, 1, 2, 1, 1, 1, 2] CREATE RULE ALLOW NODE 2 TO NODE 1 ON SERVICE 1 PORT 1 AT INDEX 2 + # 1, 2, 1, 6, 0, 0, 1 ALLOW NODE 2 TO NODE 1 ON SERVICE 1 SERVICE ANY PORT ANY AT INDEX 1 action_key = 1 # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE for action_decision in range(3): diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index e2718c53..3c2e8125 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -26,12 +26,12 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip] -# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE) +# Choice whether to have an ALLOW or DENY implicit rule or not (True or False) apply_implicit_rule: True # Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY) implicit_acl_rule: DENY # Total number of ACL rules allowed in the environment -max_number_acl_rules: 10 +max_number_acl_rules: 3 observation_space: components: diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 5408bee6..bde8a826 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -18,7 +18,7 @@ from tests.conftest import _get_primaite_env_from_config def run_generic_set_actions(env: Primaite): """Run against a generic agent with specified blue agent actions.""" # Reset the environment at the start of the episode - # env.reset() + env.reset() training_config = env.training_config for episode in range(0, training_config.num_episodes): for step in range(0, training_config.num_steps): @@ -31,12 +31,14 @@ def run_generic_set_actions(env: Primaite): # [1, 1, 2, 1, 1, 1, 2] ACL Action # Creates an ACL rule # Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2 - action = 291 + # Rule in current observation: [2, 2, 3, 2, 2, 2] + action = 43 elif step == 7: # [1, 1, 3, 1, 2, 2, 1] ACL Action # Creates an ACL rule # Allows traffic from PC1 to SWITCH 1 on port UDP at position 1 - action = 425 + # 3, 1, 1, 1, 1, + action = 96 # Run the simulation step on the live environment obs, reward, done, info = env.step(action) # Update observations space and return @@ -282,7 +284,7 @@ class TestAccessControlList: env.update_environent_obs() # we have two ACLs - assert env.env_obs.shape == (5, 2) + assert env.env_obs.shape == (6 * 3) def test_values(self, env: Primaite): """Test that traffic values are encoded correctly. @@ -305,7 +307,7 @@ class TestAccessControlList: print(obs) assert np.array_equal(obs, []) - def test_observation_space(self): + def test_observation_space_with_implicit_rule(self): """Test observation space is what is expected when an agent adds ACLs during an episode.""" # Used to use env from test fixture but AtrributeError function object has no 'training_config' env = _get_primaite_env_from_config( @@ -314,5 +316,17 @@ class TestAccessControlList: lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml", ) run_generic_set_actions(env) + obs = env.env_obs + """ + Observation space at the end of the episode. + At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0 + (0 represents its initial position at top of ACL list) + On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0 + On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1 + THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION + """ - # print("observation space",env.observation_space) + # assert current_obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2] + assert np.array_equal( + obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2] + )