diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index e356d276..7eeef731 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -77,7 +77,7 @@ class AccessControlList: # If there has been no rule to allow the IER through, it will return a blocked signal by default return True - def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port, _position): """ Adds a new rule. @@ -87,10 +87,10 @@ class AccessControlList: _dest_ip: the destination IP address _protocol: the protocol _port: the port + _position: position to insert ACL rule into ACL list """ new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - hash_value = hash(new_rule) - self.acl[hash_value] = new_rule + self.acl.insert(_position, new_rule) def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): """ diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index c268a766..2faff0f5 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -126,3 +126,10 @@ class LinkStatus(Enum): MEDIUM = 2 HIGH = 3 OVERLOAD = 4 + + +class ImplicitFirewallRule(Enum): + """Implicit firewall rule.""" + + DENY = 0 + ALLOW = 1 diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..d9155e47 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -6,7 +6,12 @@ from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union import numpy as np from gym import spaces -from primaite.common.enums import FileSystemState, HardwareState, SoftwareState +from primaite.common.enums import ( + FileSystemState, + HardwareState, + ImplicitFirewallRule, + SoftwareState, +) from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode @@ -296,6 +301,88 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs +class AccessControlList(AbstractObservationComponent): + """Flat list of all the Access Control Rules in the Access Control List. + + The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by + integers. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + :param acl_implicit_rule: Whether to have an implicit DENY or implicit ALLOW ACL rule at the end of the ACL list + Default is 0 DENY, 1 ALLOW + :type acl_implicit_rule: ImplicitFirewallRule Enumeration (ALLOW or DENY) + :param max_acl_rules: Maximum number of ACLs allowed in the environment + :type max_acl_rules: int + + Each ACL Rule has 6 elements. It will have the following structure: + .. code-block:: + [ + acl_rule1 permission, + acl_rule1 source_ip, + acl_rule1 dest_ip, + acl_rule1 protocol, + acl_rule1 port, + acl_rule1 position, + acl_rule2 permission, + acl_rule2 source_ip, + acl_rule2 dest_ip, + acl_rule2 protocol, + acl_rule2 port, + acl_rule2 position, + ... + ] + """ + + _DATA_TYPE: type = np.int64 + + def __init__( + self, + env: "Primaite", + acl_implicit_rule=ImplicitFirewallRule.DENY, + max_acl_rules: int = 5, + ): + super().__init__(env) + + self.acl_implicit_rule: ImplicitFirewallRule = acl_implicit_rule + self.max_acl_rules = max_acl_rules + + # 1. Define the shape of your observation space component + acl_shape = [ + len(ImplicitFirewallRule), + len(env.nodes), + len(env.nodes), + len(env.services_list), + len(env.ports_list), + len(env.acl), + ] + shape = acl_shape * self.env.max_acl_rules + + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) + + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + + def update(self): + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.AccessControlList` + """ + obs = [] + for acl_rule in self.env.acl: + permission = acl_rule.permission + source_ip = acl_rule.source_ip + dest_ip = acl_rule.dest_ip + protocol = acl_rule.protocol + port = acl_rule.port + position = self.env.acl.index(acl_rule) + + obs.extend([permission, source_ip, dest_ip, protocol, port, position]) + + self.current_observation[:] = obs + + class ObservationsHandler: """Component-based observation space handler. diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 2f2a071d..a61372ad 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -21,6 +21,7 @@ from primaite.common.enums import ( ActionType, FileSystemState, HardwareState, + ImplicitFirewallRule, NodePOLInitiator, NodePOLType, NodeType, @@ -157,6 +158,12 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler + # Set by main_config + # Adds a DENY ALL or ALLOW ALL to the end of the Access Control List + self.acl_implicit_rule = ImplicitFirewallRule.DENY + + # Sets a limit to how many ACL + self.max_acl_rules = 0 # Open the config file and build the environment laydown try: self.config_file = open(self.config_values.config_filename_use_case, "r")