diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 7eeef731..0b403556 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -8,10 +8,12 @@ from primaite.acl.acl_rule import ACLRule class AccessControlList: """Access Control List class.""" - def __init__(self): + def __init__(self, implicit_permission): """Init.""" # A list of ACL Rules self.acl: List[ACLRule] = [] + self.acl_implicit_rule = implicit_permission + self.max_acl_rules: int def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): """ @@ -103,6 +105,7 @@ class AccessControlList: _protocol: the protocol _port: the port """ + # Add check so you cant remove implicit rule rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) # There will not always be something 'popable' since the agent will be trying random things diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 2faff0f5..801494ef 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -128,7 +128,7 @@ class LinkStatus(Enum): OVERLOAD = 4 -class ImplicitFirewallRule(Enum): +class RulePermissionType(Enum): """Implicit firewall rule.""" DENY = 0 diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index d9155e47..865b4328 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -9,7 +9,7 @@ from gym import spaces from primaite.common.enums import ( FileSystemState, HardwareState, - ImplicitFirewallRule, + RulePermissionType, SoftwareState, ) from primaite.nodes.active_node import ActiveNode @@ -336,25 +336,16 @@ class AccessControlList(AbstractObservationComponent): _DATA_TYPE: type = np.int64 - def __init__( - self, - env: "Primaite", - acl_implicit_rule=ImplicitFirewallRule.DENY, - max_acl_rules: int = 5, - ): + def __init__(self, env: "Primaite"): super().__init__(env) - self.acl_implicit_rule: ImplicitFirewallRule = acl_implicit_rule - self.max_acl_rules = max_acl_rules - # 1. Define the shape of your observation space component acl_shape = [ - len(ImplicitFirewallRule), + len(RulePermissionType), len(env.nodes), len(env.nodes), len(env.services_list), len(env.ports_list), - len(env.acl), ] shape = acl_shape * self.env.max_acl_rules @@ -394,6 +385,7 @@ class ObservationsHandler: "NODE_LINK_TABLE": NodeLinkTable, "NODE_STATUSES": NodeStatuses, "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, + "ACCESS_CONTROL_LIST": AccessControlList, } def __init__(self): diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index a61372ad..c5aaf9cc 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -21,11 +21,11 @@ from primaite.common.enums import ( ActionType, FileSystemState, HardwareState, - ImplicitFirewallRule, NodePOLInitiator, NodePOLType, NodeType, Priority, + RulePermissionType, SoftwareState, ) from primaite.common.service import Service @@ -160,7 +160,7 @@ class Primaite(Env): # Set by main_config # Adds a DENY ALL or ALLOW ALL to the end of the Access Control List - self.acl_implicit_rule = ImplicitFirewallRule.DENY + self.acl_implicit_rule = RulePermissionType.DENY # Sets a limit to how many ACL self.max_acl_rules = 0 @@ -1173,7 +1173,7 @@ class Primaite(Env): def create_acl_action_dict(self): """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.""" # reserve 0 action to be a nothing action - actions = {0: [0, 0, 0, 0, 0, 0]} + actions = {0: [0, 0, 0, 0, 0, 0, 0]} action_key = 1 # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE @@ -1185,14 +1185,16 @@ class Primaite(Env): for dest_ip in range(self.num_nodes + 1): for protocol in range(self.num_services + 1): for port in range(self.num_ports + 1): - action = [ - action_decision, - action_permission, - source_ip, - dest_ip, - protocol, - port, - ] + for position in range(self.max_acl_rules - 1): + action = [ + action_decision, + action_permission, + source_ip, + dest_ip, + protocol, + port, + position, + ] # Check to see if its an action we want to include as possible i.e. not a nothing action if is_valid_acl_action_extra(action): actions[action_key] = action