#901 - Replaced "ALLOW" with RulePermissionType.ALLOW
- Added Explicit ALLOW to test_configs in order for tests to work - Added typing to access_control_list.py and acl_rule.py
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,12 +16,11 @@ class AccessControlList:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
print(self.acl_implicit_permission, "ACL IMPLICIT PERMISSION")
|
||||
# Implicit rule in ACL list
|
||||
if self.acl_implicit_permission == "DENY":
|
||||
self.acl_implicit_rule = ACLRule("DENY", "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == "ALLOW":
|
||||
self.acl_implicit_rule = ACLRule("ALLOW", "ANY", "ANY", "ANY", "ANY")
|
||||
if self.acl_implicit_permission == RulePermissionType.DENY:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
|
||||
else:
|
||||
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
|
||||
|
||||
@@ -76,9 +76,9 @@ class AccessControlList:
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == "DENY":
|
||||
if rule.get_permission() == RulePermissionType.DENY:
|
||||
return True
|
||||
elif rule.get_permission() == "ALLOW":
|
||||
elif rule.get_permission() == RulePermissionType.ALLOW:
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
@@ -121,7 +121,9 @@ class AccessControlList:
|
||||
else:
|
||||
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
|
||||
|
||||
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
def remove_rule(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def __init__(self, _permission: RulePermissionType, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
|
||||
@@ -489,7 +489,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == "DENY":
|
||||
if permission == RulePermissionType.DENY:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
|
||||
@@ -38,6 +38,11 @@ observation_space:
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
|
||||
@@ -36,6 +36,11 @@ observation_space:
|
||||
time_delay: 1
|
||||
# Filename of the scenario / laydown
|
||||
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: ALLOW
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 4
|
||||
|
||||
session_type: TRAIN
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
|
||||
@@ -3,40 +3,41 @@
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
def test_acl_address_match_1():
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
@@ -44,10 +45,10 @@ def test_acl_address_match_4():
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "ALLOW"
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
@@ -68,10 +69,10 @@ def test_check_acl_block_affirmative():
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = "DENY"
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
@@ -93,12 +94,12 @@ def test_check_acl_block_negative():
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList("DENY", 10)
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
@@ -106,10 +107,10 @@ def test_rule_hash():
|
||||
def test_delete_rule():
|
||||
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList("ALLOW", 10)
|
||||
acl = AccessControlList(RulePermissionType.ALLOW, 10)
|
||||
|
||||
# Create a first rule
|
||||
acl_rule_permission = "DENY"
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
@@ -126,7 +127,7 @@ def test_delete_rule():
|
||||
)
|
||||
|
||||
# Create a second rule
|
||||
acl_rule_permission = "DENY"
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "20"
|
||||
acl_rule_destination = "30"
|
||||
acl_rule_protocol = "FTP"
|
||||
@@ -143,7 +144,7 @@ def test_delete_rule():
|
||||
)
|
||||
|
||||
# Create a third rule
|
||||
acl_rule_permission = "ALLOW"
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.3"
|
||||
acl_rule_destination = "192.168.1.1"
|
||||
acl_rule_protocol = "UDP"
|
||||
@@ -159,7 +160,7 @@ def test_delete_rule():
|
||||
acl_position_in_list,
|
||||
)
|
||||
# Remove the second ACL rule added from the list
|
||||
acl.remove_rule("DENY", "20", "30", "FTP", "21")
|
||||
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
|
||||
|
||||
assert len(acl.acl) == 10
|
||||
assert acl.acl[2] is None
|
||||
|
||||
@@ -26,16 +26,16 @@ def test_seeded_learning(temp_primaite_session):
|
||||
now work. If not, then you've got a bug :).
|
||||
"""
|
||||
expected_mean_reward_per_episode = {
|
||||
1: -90.703125,
|
||||
2: -91.15234375,
|
||||
3: -87.5,
|
||||
4: -92.2265625,
|
||||
5: -94.6875,
|
||||
6: -91.19140625,
|
||||
7: -88.984375,
|
||||
8: -88.3203125,
|
||||
9: -112.79296875,
|
||||
10: -100.01953125,
|
||||
1: -20.7421875,
|
||||
2: -19.82421875,
|
||||
3: -17.01171875,
|
||||
4: -19.08203125,
|
||||
5: -21.93359375,
|
||||
6: -20.21484375,
|
||||
7: -15.546875,
|
||||
8: -12.08984375,
|
||||
9: -17.59765625,
|
||||
10: -14.6875,
|
||||
}
|
||||
|
||||
with temp_primaite_session as session:
|
||||
@@ -44,6 +44,7 @@ def test_seeded_learning(temp_primaite_session):
|
||||
), "Expected output is based upon a agent that was trained with seed 67890"
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
print(actual_mean_reward_per_episode, "THISt")
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user