#901 - Replaced "ALLOW" with RulePermissionType.ALLOW

- Added Explicit ALLOW to test_configs in order for tests to work
- Added typing to access_control_list.py and acl_rule.py
This commit is contained in:
SunilSamra
2023-07-17 20:40:00 +01:00
parent a2f43b5abc
commit 9520cfea24
7 changed files with 55 additions and 40 deletions

View File

@@ -4,6 +4,7 @@ import logging
from typing import Final, List, Union
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
@@ -15,12 +16,11 @@ class AccessControlList:
"""Init."""
# Implicit ALLOW or DENY firewall spec
self.acl_implicit_permission = implicit_permission
print(self.acl_implicit_permission, "ACL IMPLICIT PERMISSION")
# Implicit rule in ACL list
if self.acl_implicit_permission == "DENY":
self.acl_implicit_rule = ACLRule("DENY", "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == "ALLOW":
self.acl_implicit_rule = ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY")
if self.acl_implicit_permission == RulePermissionType.DENY:
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
else:
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
@@ -76,9 +76,9 @@ class AccessControlList:
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule.get_permission() == "DENY":
if rule.get_permission() == RulePermissionType.DENY:
return True
elif rule.get_permission() == "ALLOW":
elif rule.get_permission() == RulePermissionType.ALLOW:
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default
@@ -121,7 +121,9 @@ class AccessControlList:
else:
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
def remove_rule(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Removes a rule.

View File

@@ -1,11 +1,12 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType
class ACLRule:
"""Access Control List Rule class."""
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def __init__(self, _permission: RulePermissionType, _source_ip, _dest_ip, _protocol, _port):
"""
Initialise an ACL Rule.

View File

@@ -489,7 +489,7 @@ class AccessControlList(AbstractObservationComponent):
# Map each ACL attribute from what it was to an integer to fit the observation space
source_ip_int = None
dest_ip_int = None
if permission == "DENY":
if permission == RulePermissionType.DENY:
permission_int = 1
else:
permission_int = 2

View File

@@ -38,6 +38,11 @@ observation_space:
# Time delay between steps (for generic agents)
time_delay: 1
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: ALLOW
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 4
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAIN
# Determine whether to load an agent from file

View File

@@ -36,6 +36,11 @@ observation_space:
time_delay: 1
# Filename of the scenario / laydown
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: ALLOW
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 4
session_type: TRAIN
# Determine whether to load an agent from file
load_agent: False

View File

@@ -3,40 +3,41 @@
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
def test_acl_address_match_1():
"""Test that matching IP addresses produce True."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_2():
"""Test that mismatching IP addresses produce False."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
def test_acl_address_match_3():
"""Test the ANY condition for source IP addresses produce True."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_4():
"""Test the ANY condition for dest IP addresses produce True."""
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
@@ -44,10 +45,10 @@ def test_acl_address_match_4():
def test_check_acl_block_affirmative():
"""Test the block function (affirmative)."""
# Create the Access Control List
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
# Create a rule
acl_rule_permission = "ALLOW"
acl_rule_permission = RulePermissionType.ALLOW
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
@@ -68,10 +69,10 @@ def test_check_acl_block_affirmative():
def test_check_acl_block_negative():
"""Test the block function (negative)."""
# Create the Access Control List
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
# Create a rule
acl_rule_permission = "DENY"
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
@@ -93,12 +94,12 @@ def test_check_acl_block_negative():
def test_rule_hash():
"""Test the rule hash."""
# Create the Access Control List
acl = AccessControlList("DENY", 10)
acl = AccessControlList(RulePermissionType.DENY, 10)
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
assert hash_value_local == hash_value_remote
@@ -106,10 +107,10 @@ def test_rule_hash():
def test_delete_rule():
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
# Create the Access Control List
acl = AccessControlList("ALLOW", 10)
acl = AccessControlList(RulePermissionType.ALLOW, 10)
# Create a first rule
acl_rule_permission = "DENY"
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "192.168.1.1"
acl_rule_destination = "192.168.1.2"
acl_rule_protocol = "TCP"
@@ -126,7 +127,7 @@ def test_delete_rule():
)
# Create a second rule
acl_rule_permission = "DENY"
acl_rule_permission = RulePermissionType.DENY
acl_rule_source = "20"
acl_rule_destination = "30"
acl_rule_protocol = "FTP"
@@ -143,7 +144,7 @@ def test_delete_rule():
)
# Create a third rule
acl_rule_permission = "ALLOW"
acl_rule_permission = RulePermissionType.ALLOW
acl_rule_source = "192.168.1.3"
acl_rule_destination = "192.168.1.1"
acl_rule_protocol = "UDP"
@@ -159,7 +160,7 @@ def test_delete_rule():
acl_position_in_list,
)
# Remove the second ACL rule added from the list
acl.remove_rule("DENY", "20", "30", "FTP", "21")
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
assert len(acl.acl) == 10
assert acl.acl[2] is None

View File

@@ -26,16 +26,16 @@ def test_seeded_learning(temp_primaite_session):
now work. If not, then you've got a bug :).
"""
expected_mean_reward_per_episode = {
1: -90.703125,
2: -91.15234375,
3: -87.5,
4: -92.2265625,
5: -94.6875,
6: -91.19140625,
7: -88.984375,
8: -88.3203125,
9: -112.79296875,
10: -100.01953125,
1: -20.7421875,
2: -19.82421875,
3: -17.01171875,
4: -19.08203125,
5: -21.93359375,
6: -20.21484375,
7: -15.546875,
8: -12.08984375,
9: -17.59765625,
10: -14.6875,
}
with temp_primaite_session as session:
@@ -44,6 +44,7 @@ def test_seeded_learning(temp_primaite_session):
), "Expected output is based upon a agent that was trained with seed 67890"
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
print(actual_mean_reward_per_episode, "THISt")
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode