Type hint ACLs

This commit is contained in:
Marek Wolan
2023-07-12 16:58:12 +01:00
parent c61770825a
commit f4a70394e0
2 changed files with 26 additions and 22 deletions

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network.""" """A class that implements the access control list implementation for the network."""
from typing import Dict from typing import Dict, Optional
from primaite.acl.acl_rule import ACLRule from primaite.acl.acl_rule import ACLRule
@@ -8,9 +8,9 @@ from primaite.acl.acl_rule import ACLRule
class AccessControlList: class AccessControlList:
"""Access Control List class.""" """Access Control List class."""
def __init__(self): def __init__(self) -> None:
"""Initialise an empty AccessControlList.""" """Initialise an empty AccessControlList."""
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches. """Checks for IP address matches.
@@ -61,7 +61,7 @@ class AccessControlList:
# If there has been no rule to allow the IER through, it will return a blocked signal by default # If there has been no rule to allow the IER through, it will return a blocked signal by default
return True return True
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
""" """
Adds a new rule. Adds a new rule.
@@ -76,7 +76,9 @@ class AccessControlList:
hash_value = hash(new_rule) hash_value = hash(new_rule)
self.acl[hash_value] = new_rule self.acl[hash_value] = new_rule
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): def remove_rule(
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> Optional[int]:
""" """
Removes a rule. Removes a rule.
@@ -95,11 +97,11 @@ class AccessControlList:
except Exception: except Exception:
return return
def remove_all_rules(self): def remove_all_rules(self) -> None:
"""Removes all rules.""" """Removes all rules."""
self.acl.clear() self.acl.clear()
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
""" """
Produces a hash value for a rule. Produces a hash value for a rule.
@@ -117,7 +119,9 @@ class AccessControlList:
hash_value = hash(rule) hash_value = hash(rule)
return hash_value return hash_value
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port): def get_relevant_rules(
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
) -> Dict[int, ACLRule]:
"""Get all ACL rules that relate to the given arguments. """Get all ACL rules that relate to the given arguments.
:param _source_ip_address: the source IP address to check :param _source_ip_address: the source IP address to check
@@ -125,9 +129,9 @@ class AccessControlList:
:param _protocol: the protocol to check :param _protocol: the protocol to check
:param _port: the port to check :param _port: the port to check
:return: Dictionary of all ACL rules that relate to the given arguments :return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[str, ACLRule] :rtype: Dict[int, ACLRule]
""" """
relevant_rules = {} relevant_rules: Dict[int, ACLRule] = {}
for rule_key, rule_value in self.acl.items(): for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):

View File

@@ -5,7 +5,7 @@
class ACLRule: class ACLRule:
"""Access Control List Rule class.""" """Access Control List Rule class."""
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port): def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
""" """
Initialise an ACL Rule. Initialise an ACL Rule.
@@ -15,13 +15,13 @@ class ACLRule:
:param _protocol: The rule protocol :param _protocol: The rule protocol
:param _port: The rule port :param _port: The rule port
""" """
self.permission = _permission self.permission: str = _permission
self.source_ip = _source_ip self.source_ip: str = _source_ip
self.dest_ip = _dest_ip self.dest_ip: str = _dest_ip
self.protocol = _protocol self.protocol: str = _protocol
self.port = _port self.port: str = _port
def __hash__(self): def __hash__(self) -> int:
""" """
Override the hash function. Override the hash function.
@@ -38,7 +38,7 @@ class ACLRule:
) )
) )
def get_permission(self): def get_permission(self) -> str:
""" """
Gets the permission attribute. Gets the permission attribute.
@@ -47,7 +47,7 @@ class ACLRule:
""" """
return self.permission return self.permission
def get_source_ip(self): def get_source_ip(self) -> str:
""" """
Gets the source IP address attribute. Gets the source IP address attribute.
@@ -56,7 +56,7 @@ class ACLRule:
""" """
return self.source_ip return self.source_ip
def get_dest_ip(self): def get_dest_ip(self) -> str:
""" """
Gets the desintation IP address attribute. Gets the desintation IP address attribute.
@@ -65,7 +65,7 @@ class ACLRule:
""" """
return self.dest_ip return self.dest_ip
def get_protocol(self): def get_protocol(self) -> str:
""" """
Gets the protocol attribute. Gets the protocol attribute.
@@ -74,7 +74,7 @@ class ACLRule:
""" """
return self.protocol return self.protocol
def get_port(self): def get_port(self) -> str:
""" """
Gets the port attribute. Gets the port attribute.