diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index ce942111..c9674e48 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -35,13 +35,14 @@ class AccessControlList: def acl(self): """Public access method for private _acl. - Adds implicit rule to end of acl list and - Pads out rest of list (if empty) with -1. + Adds implicit rule to the BACK of the list after ALL the OTHER ACL rules and + pads out rest of list (if it is empty) with None. """ if self.acl_implicit_rule is not None: acl_list = self._acl + [self.acl_implicit_rule] else: acl_list = self._acl + return acl_list + [None] * (self.max_acl_rules - len(acl_list)) def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: @@ -113,13 +114,17 @@ class AccessControlList: return new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + # Checks position is in correct range if self.max_acl_rules - 1 > position_index > -1: try: _LOGGER.info(f"Position {position_index} is valid.") + # Check to see Agent will not overwrite current ACL in ACL list if self._acl[position_index] is None: _LOGGER.info(f"Inserting rule {new_rule} at position {position_index}") + # Adds rule self._acl[position_index] = new_rule else: + # Cannot overwrite it _LOGGER.info(f"Error: inserting rule at non-empty position {position_index}") return except Exception: @@ -140,7 +145,7 @@ class AccessControlList: """ # Add check so you cant remove implicit rule rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - # There will not always be something 'popable' since the agent will be trying random things + # There will not always be something removable since the agent will be trying random things try: self.acl.remove(rule) except Exception: diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index c743e41a..66f9e1eb 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -408,9 +408,6 @@ class AccessControlList(AbstractObservationComponent): The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers. - :param env: The environment that forms the basis of the observations - :type env: Primaite - Each ACL Rule has 6 elements. It will have the following structure: .. code-block:: [ @@ -429,6 +426,7 @@ class AccessControlList(AbstractObservationComponent): ... ] + Terms (for ACL Observation Space): [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW) [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) @@ -436,27 +434,37 @@ class AccessControlList(AbstractObservationComponent): [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol) [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port) [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list) + + NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object. """ _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): + """ + Initialise an AccessControlList observation component. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + """ super().__init__(env) # 1. Define the shape of your observation space component + # The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports. + # Number of ACL rules incremented by 1 for positions starting at index 0. acl_shape = [ len(RulePermissionType), len(env.nodes) + 2, len(env.nodes) + 2, len(env.services_list) + 2, len(env.ports_list) + 2, - env.max_number_acl_rules + 1, + env.max_number_acl_rules, ] shape = acl_shape * self.env.max_number_acl_rules # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) - # print("obs space:", self.space) + # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) @@ -468,7 +476,7 @@ class AccessControlList(AbstractObservationComponent): The structure of the observation space is described in :class:`.AccessControlList` """ obs = [] - # print("starting len", len(self.env.acl.acl)) + for index in range(0, len(self.env.acl.acl)): acl_rule = self.env.acl.acl[index] if isinstance(acl_rule, ACLRule): @@ -478,7 +486,7 @@ class AccessControlList(AbstractObservationComponent): protocol = acl_rule.protocol port = acl_rule.port position = index - + # Map each ACL attribute from what it was to an integer to fit the observation space source_ip_int = None dest_ip_int = None if permission == "DENY": @@ -488,6 +496,7 @@ class AccessControlList(AbstractObservationComponent): if source_ip == "ANY": source_ip_int = 1 else: + # Map Node ID (+ 1) to source IP address nodes = list(self.env.nodes.values()) for node in nodes: if ( @@ -498,6 +507,8 @@ class AccessControlList(AbstractObservationComponent): if dest_ip == "ANY": dest_ip_int = 1 else: + # Map Node ID (+ 1) to dest IP address + # Index of Nodes start at 1 so + 1 is needed so NA can be added. nodes = list(self.env.nodes.values()) for node in nodes: if ( @@ -507,6 +518,7 @@ class AccessControlList(AbstractObservationComponent): if protocol == "ANY": protocol_int = 1 else: + # Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY try: protocol_int = self.env.services_list.index(protocol) + 2 except AttributeError: @@ -520,7 +532,7 @@ class AccessControlList(AbstractObservationComponent): else: _LOGGER.info(f"Port {port} could not be found.") port_int = None - + # Add to current obs obs.extend( [ permission_int, @@ -533,9 +545,9 @@ class AccessControlList(AbstractObservationComponent): ) else: + # The Nothing or NA representation of 'NONE' ACL rules obs.extend([0, 0, 0, 0, 0, 0]) - # print("current obs", obs, "\n" ,len(obs)) self.current_observation[:] = obs def generate_structure(self): diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index e4702c84..a06e93ed 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -19,7 +19,7 @@ def run_generic_set_actions(env: Primaite): # TEMP - random action for now # action = env.blue_agent_action(obs) action = 0 - print("Episode:", episode, "\nStep:", step) + # print("Episode:", episode, "\nStep:", step) if step == 5: # [1, 1, 2, 1, 1, 1, 1(position)] # Creates an ACL rule