Type hint ACLs
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
|
||||
@@ -8,9 +8,9 @@ from primaite.acl.acl_rule import ACLRule
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialise an empty AccessControlList."""
|
||||
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
|
||||
self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
@@ -61,7 +61,7 @@ class AccessControlList:
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
@@ -76,7 +76,9 @@ class AccessControlList:
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def remove_rule(
|
||||
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
@@ -95,11 +97,11 @@ class AccessControlList:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def remove_all_rules(self):
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
@@ -117,7 +119,9 @@ class AccessControlList:
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
def get_relevant_rules(
|
||||
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all ACL rules that relate to the given arguments.
|
||||
|
||||
:param _source_ip_address: the source IP address to check
|
||||
@@ -125,9 +129,9 @@ class AccessControlList:
|
||||
:param _protocol: the protocol to check
|
||||
:param _port: the port to check
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[str, ACLRule]
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules = {}
|
||||
relevant_rules: Dict[int, ACLRule] = {}
|
||||
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
@@ -15,13 +15,13 @@ class ACLRule:
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission = _permission
|
||||
self.source_ip = _source_ip
|
||||
self.dest_ip = _dest_ip
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.permission: str = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
Override the hash function.
|
||||
|
||||
@@ -38,7 +38,7 @@ class ACLRule:
|
||||
)
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
def get_permission(self) -> str:
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
|
||||
@@ -47,7 +47,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
def get_source_ip(self) -> str:
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
|
||||
@@ -56,7 +56,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
def get_dest_ip(self) -> str:
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
@@ -65,7 +65,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
|
||||
@@ -74,7 +74,7 @@ class ACLRule:
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets the port attribute.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user