diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index bf008d26..936dcb12 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -4,6 +4,7 @@ import logging from typing import Final, List, Union from primaite.acl.acl_rule import ACLRule +from primaite.common.enums import RulePermissionType _LOGGER: Final[logging.Logger] = logging.getLogger(__name__) @@ -15,12 +16,11 @@ class AccessControlList: """Init.""" # Implicit ALLOW or DENY firewall spec self.acl_implicit_permission = implicit_permission - print(self.acl_implicit_permission, "ACL IMPLICIT PERMISSION") # Implicit rule in ACL list - if self.acl_implicit_permission == "DENY": - self.acl_implicit_rule = ACLRule("DENY", "ANY", "ANY", "ANY", "ANY") - elif self.acl_implicit_permission == "ALLOW": - self.acl_implicit_rule = ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY") + if self.acl_implicit_permission == RulePermissionType.DENY: + self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY") + elif self.acl_implicit_permission == RulePermissionType.ALLOW: + self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY") else: raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}") @@ -76,9 +76,9 @@ class AccessControlList: str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" ): # There's a matching rule. Get the permission - if rule.get_permission() == "DENY": + if rule.get_permission() == RulePermissionType.DENY: return True - elif rule.get_permission() == "ALLOW": + elif rule.get_permission() == RulePermissionType.ALLOW: return False # If there has been no rule to allow the IER through, it will return a blocked signal by default @@ -121,7 +121,9 @@ class AccessControlList: else: _LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule") - def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None: + def remove_rule( + self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str + ) -> None: """ Removes a rule. diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index a1fd93f2..49c0a84c 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -1,11 +1,12 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """A class that implements an access control list rule.""" +from primaite.common.enums import RulePermissionType class ACLRule: """Access Control List Rule class.""" - def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def __init__(self, _permission: RulePermissionType, _source_ip, _dest_ip, _protocol, _port): """ Initialise an ACL Rule. diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a95f720c..7695c916 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -489,7 +489,7 @@ class AccessControlList(AbstractObservationComponent): # Map each ACL attribute from what it was to an integer to fit the observation space source_ip_int = None dest_ip_int = None - if permission == "DENY": + if permission == RulePermissionType.DENY: permission_int = 1 else: permission_int = 2 diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml index 2ac8f59a..df826c87 100644 --- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -38,6 +38,11 @@ observation_space: # Time delay between steps (for generic agents) time_delay: 1 +# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY) +implicit_acl_rule: ALLOW +# Total number of ACL rules allowed in the environment +max_number_acl_rules: 4 + # Type of session to be run (TRAINING or EVALUATION) session_type: TRAIN # Determine whether to load an agent from file diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml index a9986d5b..aa1cce38 100644 --- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -36,6 +36,11 @@ observation_space: time_delay: 1 # Filename of the scenario / laydown +# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY) +implicit_acl_rule: ALLOW +# Total number of ACL rules allowed in the environment +max_number_acl_rules: 4 + session_type: TRAIN # Determine whether to load an agent from file load_agent: False diff --git a/tests/test_acl.py b/tests/test_acl.py index 088da5eb..aeb95149 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -3,40 +3,41 @@ from primaite.acl.access_control_list import AccessControlList from primaite.acl.acl_rule import ACLRule +from primaite.common.enums import RulePermissionType def test_acl_address_match_1(): """Test that matching IP addresses produce True.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) - rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80") + rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80") assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True def test_acl_address_match_2(): """Test that mismatching IP addresses produce False.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) - rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80") + rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80") assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False def test_acl_address_match_3(): """Test the ANY condition for source IP addresses produce True.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) - rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80") + rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80") assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True def test_acl_address_match_4(): """Test the ANY condition for dest IP addresses produce True.""" - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) - rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80") + rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80") assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True @@ -44,10 +45,10 @@ def test_acl_address_match_4(): def test_check_acl_block_affirmative(): """Test the block function (affirmative).""" # Create the Access Control List - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) # Create a rule - acl_rule_permission = "ALLOW" + acl_rule_permission = RulePermissionType.ALLOW acl_rule_source = "192.168.1.1" acl_rule_destination = "192.168.1.2" acl_rule_protocol = "TCP" @@ -68,10 +69,10 @@ def test_check_acl_block_affirmative(): def test_check_acl_block_negative(): """Test the block function (negative).""" # Create the Access Control List - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) # Create a rule - acl_rule_permission = "DENY" + acl_rule_permission = RulePermissionType.DENY acl_rule_source = "192.168.1.1" acl_rule_destination = "192.168.1.2" acl_rule_protocol = "TCP" @@ -93,12 +94,12 @@ def test_check_acl_block_negative(): def test_rule_hash(): """Test the rule hash.""" # Create the Access Control List - acl = AccessControlList("DENY", 10) + acl = AccessControlList(RulePermissionType.DENY, 10) - rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") + rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80") hash_value_local = hash(rule) - hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") + hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80") assert hash_value_local == hash_value_remote @@ -106,10 +107,10 @@ def test_rule_hash(): def test_delete_rule(): """Adds 3 rules and deletes 1 rule and checks its deletion.""" # Create the Access Control List - acl = AccessControlList("ALLOW", 10) + acl = AccessControlList(RulePermissionType.ALLOW, 10) # Create a first rule - acl_rule_permission = "DENY" + acl_rule_permission = RulePermissionType.DENY acl_rule_source = "192.168.1.1" acl_rule_destination = "192.168.1.2" acl_rule_protocol = "TCP" @@ -126,7 +127,7 @@ def test_delete_rule(): ) # Create a second rule - acl_rule_permission = "DENY" + acl_rule_permission = RulePermissionType.DENY acl_rule_source = "20" acl_rule_destination = "30" acl_rule_protocol = "FTP" @@ -143,7 +144,7 @@ def test_delete_rule(): ) # Create a third rule - acl_rule_permission = "ALLOW" + acl_rule_permission = RulePermissionType.ALLOW acl_rule_source = "192.168.1.3" acl_rule_destination = "192.168.1.1" acl_rule_protocol = "UDP" @@ -159,7 +160,7 @@ def test_delete_rule(): acl_position_in_list, ) # Remove the second ACL rule added from the list - acl.remove_rule("DENY", "20", "30", "FTP", "21") + acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21") assert len(acl.acl) == 10 assert acl.acl[2] is None diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index 637c1693..1dcb11a3 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -26,16 +26,16 @@ def test_seeded_learning(temp_primaite_session): now work. If not, then you've got a bug :). """ expected_mean_reward_per_episode = { - 1: -90.703125, - 2: -91.15234375, - 3: -87.5, - 4: -92.2265625, - 5: -94.6875, - 6: -91.19140625, - 7: -88.984375, - 8: -88.3203125, - 9: -112.79296875, - 10: -100.01953125, + 1: -20.7421875, + 2: -19.82421875, + 3: -17.01171875, + 4: -19.08203125, + 5: -21.93359375, + 6: -20.21484375, + 7: -15.546875, + 8: -12.08984375, + 9: -17.59765625, + 10: -14.6875, } with temp_primaite_session as session: @@ -44,6 +44,7 @@ def test_seeded_learning(temp_primaite_session): ), "Expected output is based upon a agent that was trained with seed 67890" session.learn() actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict() + print(actual_mean_reward_per_episode, "THISt") assert actual_mean_reward_per_episode == expected_mean_reward_per_episode