Type hint ACLs

This commit is contained in:
Marek Wolan
2023-07-12 16:58:12 +01:00
parent c61770825a
commit f4a70394e0
2 changed files with 26 additions and 22 deletions

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict
from typing import Dict, Optional
from primaite.acl.acl_rule import ACLRule
@@ -8,9 +8,9 @@ from primaite.acl.acl_rule import ACLRule
class AccessControlList:
"""Access Control List class."""
def __init__(self):
def __init__(self) -> None:
"""Initialise an empty AccessControlList."""
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
@@ -61,7 +61,7 @@ class AccessControlList:
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
"""
Adds a new rule.
@@ -76,7 +76,9 @@ class AccessControlList:
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def remove_rule(
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> Optional[int]:
"""
Removes a rule.
@@ -95,11 +97,11 @@ class AccessControlList:
except Exception:
return
def remove_all_rules(self):
def remove_all_rules(self) -> None:
"""Removes all rules."""
self.acl.clear()
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
"""
Produces a hash value for a rule.
@@ -117,7 +119,9 @@ class AccessControlList:
hash_value = hash(rule)
return hash_value
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
def get_relevant_rules(
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
) -> Dict[int, ACLRule]:
"""Get all ACL rules that relate to the given arguments.
:param _source_ip_address: the source IP address to check
@@ -125,9 +129,9 @@ class AccessControlList:
:param _protocol: the protocol to check
:param _port: the port to check
:return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[str, ACLRule]
:rtype: Dict[int, ACLRule]
"""
relevant_rules = {}
relevant_rules: Dict[int, ACLRule] = {}
for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):

View File

@@ -5,7 +5,7 @@
class ACLRule:
"""Access Control List Rule class."""
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
"""
Initialise an ACL Rule.
@@ -15,13 +15,13 @@ class ACLRule:
:param _protocol: The rule protocol
:param _port: The rule port
"""
self.permission = _permission
self.source_ip = _source_ip
self.dest_ip = _dest_ip
self.protocol = _protocol
self.port = _port
self.permission: str = _permission
self.source_ip: str = _source_ip
self.dest_ip: str = _dest_ip
self.protocol: str = _protocol
self.port: str = _port
def __hash__(self):
def __hash__(self) -> int:
"""
Override the hash function.
@@ -38,7 +38,7 @@ class ACLRule:
)
)
def get_permission(self):
def get_permission(self) -> str:
"""
Gets the permission attribute.
@@ -47,7 +47,7 @@ class ACLRule:
"""
return self.permission
def get_source_ip(self):
def get_source_ip(self) -> str:
"""
Gets the source IP address attribute.
@@ -56,7 +56,7 @@ class ACLRule:
"""
return self.source_ip
def get_dest_ip(self):
def get_dest_ip(self) -> str:
"""
Gets the desintation IP address attribute.
@@ -65,7 +65,7 @@ class ACLRule:
"""
return self.dest_ip
def get_protocol(self):
def get_protocol(self) -> str:
"""
Gets the protocol attribute.
@@ -74,7 +74,7 @@ class ACLRule:
"""
return self.protocol
def get_port(self):
def get_port(self) -> str:
"""
Gets the port attribute.