diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 9a8444e5..f7e65bd4 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -1,6 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """A class that implements the access control list implementation for the network.""" -from typing import Dict +from typing import Dict, Optional from primaite.acl.acl_rule import ACLRule @@ -8,9 +8,9 @@ from primaite.acl.acl_rule import ACLRule class AccessControlList: """Access Control List class.""" - def __init__(self): + def __init__(self) -> None: """Initialise an empty AccessControlList.""" - self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules + self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: """Checks for IP address matches. @@ -61,7 +61,7 @@ class AccessControlList: # If there has been no rule to allow the IER through, it will return a blocked signal by default return True - def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None: """ Adds a new rule. @@ -76,7 +76,9 @@ class AccessControlList: hash_value = hash(new_rule) self.acl[hash_value] = new_rule - def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def remove_rule( + self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str + ) -> Optional[int]: """ Removes a rule. @@ -95,11 +97,11 @@ class AccessControlList: except Exception: return - def remove_all_rules(self): + def remove_all_rules(self) -> None: """Removes all rules.""" self.acl.clear() - def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int: """ Produces a hash value for a rule. @@ -117,7 +119,9 @@ class AccessControlList: hash_value = hash(rule) return hash_value - def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port): + def get_relevant_rules( + self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str + ) -> Dict[int, ACLRule]: """Get all ACL rules that relate to the given arguments. :param _source_ip_address: the source IP address to check @@ -125,9 +129,9 @@ class AccessControlList: :param _protocol: the protocol to check :param _port: the port to check :return: Dictionary of all ACL rules that relate to the given arguments - :rtype: Dict[str, ACLRule] + :rtype: Dict[int, ACLRule] """ - relevant_rules = {} + relevant_rules: Dict[int, ACLRule] = {} for rule_key, rule_value in self.acl.items(): if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index a1fd93f2..69532376 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -5,7 +5,7 @@ class ACLRule: """Access Control List Rule class.""" - def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None: """ Initialise an ACL Rule. @@ -15,13 +15,13 @@ class ACLRule: :param _protocol: The rule protocol :param _port: The rule port """ - self.permission = _permission - self.source_ip = _source_ip - self.dest_ip = _dest_ip - self.protocol = _protocol - self.port = _port + self.permission: str = _permission + self.source_ip: str = _source_ip + self.dest_ip: str = _dest_ip + self.protocol: str = _protocol + self.port: str = _port - def __hash__(self): + def __hash__(self) -> int: """ Override the hash function. @@ -38,7 +38,7 @@ class ACLRule: ) ) - def get_permission(self): + def get_permission(self) -> str: """ Gets the permission attribute. @@ -47,7 +47,7 @@ class ACLRule: """ return self.permission - def get_source_ip(self): + def get_source_ip(self) -> str: """ Gets the source IP address attribute. @@ -56,7 +56,7 @@ class ACLRule: """ return self.source_ip - def get_dest_ip(self): + def get_dest_ip(self) -> str: """ Gets the desintation IP address attribute. @@ -65,7 +65,7 @@ class ACLRule: """ return self.dest_ip - def get_protocol(self): + def get_protocol(self) -> str: """ Gets the protocol attribute. @@ -74,7 +74,7 @@ class ACLRule: """ return self.protocol - def get_port(self): + def get_port(self) -> str: """ Gets the port attribute.