Type hint ACLs
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""A class that implements the access control list implementation for the network."""
|
"""A class that implements the access control list implementation for the network."""
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from primaite.acl.acl_rule import ACLRule
|
from primaite.acl.acl_rule import ACLRule
|
||||||
|
|
||||||
@@ -8,9 +8,9 @@ from primaite.acl.acl_rule import ACLRule
|
|||||||
class AccessControlList:
|
class AccessControlList:
|
||||||
"""Access Control List class."""
|
"""Access Control List class."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""Initialise an empty AccessControlList."""
|
"""Initialise an empty AccessControlList."""
|
||||||
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
|
self.acl: Dict[int, ACLRule] = {} # A dictionary of ACL Rules
|
||||||
|
|
||||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||||
"""Checks for IP address matches.
|
"""Checks for IP address matches.
|
||||||
@@ -61,7 +61,7 @@ class AccessControlList:
|
|||||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def add_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a new rule.
|
Adds a new rule.
|
||||||
|
|
||||||
@@ -76,7 +76,9 @@ class AccessControlList:
|
|||||||
hash_value = hash(new_rule)
|
hash_value = hash(new_rule)
|
||||||
self.acl[hash_value] = new_rule
|
self.acl[hash_value] = new_rule
|
||||||
|
|
||||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def remove_rule(
|
||||||
|
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||||
|
) -> Optional[int]:
|
||||||
"""
|
"""
|
||||||
Removes a rule.
|
Removes a rule.
|
||||||
|
|
||||||
@@ -95,11 +97,11 @@ class AccessControlList:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return
|
return
|
||||||
|
|
||||||
def remove_all_rules(self):
|
def remove_all_rules(self) -> None:
|
||||||
"""Removes all rules."""
|
"""Removes all rules."""
|
||||||
self.acl.clear()
|
self.acl.clear()
|
||||||
|
|
||||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def get_dictionary_hash(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int:
|
||||||
"""
|
"""
|
||||||
Produces a hash value for a rule.
|
Produces a hash value for a rule.
|
||||||
|
|
||||||
@@ -117,7 +119,9 @@ class AccessControlList:
|
|||||||
hash_value = hash(rule)
|
hash_value = hash(rule)
|
||||||
return hash_value
|
return hash_value
|
||||||
|
|
||||||
def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
def get_relevant_rules(
|
||||||
|
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
|
||||||
|
) -> Dict[int, ACLRule]:
|
||||||
"""Get all ACL rules that relate to the given arguments.
|
"""Get all ACL rules that relate to the given arguments.
|
||||||
|
|
||||||
:param _source_ip_address: the source IP address to check
|
:param _source_ip_address: the source IP address to check
|
||||||
@@ -125,9 +129,9 @@ class AccessControlList:
|
|||||||
:param _protocol: the protocol to check
|
:param _protocol: the protocol to check
|
||||||
:param _port: the port to check
|
:param _port: the port to check
|
||||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||||
:rtype: Dict[str, ACLRule]
|
:rtype: Dict[int, ACLRule]
|
||||||
"""
|
"""
|
||||||
relevant_rules = {}
|
relevant_rules: Dict[int, ACLRule] = {}
|
||||||
|
|
||||||
for rule_key, rule_value in self.acl.items():
|
for rule_key, rule_value in self.acl.items():
|
||||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
class ACLRule:
|
class ACLRule:
|
||||||
"""Access Control List Rule class."""
|
"""Access Control List Rule class."""
|
||||||
|
|
||||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def __init__(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||||
"""
|
"""
|
||||||
Initialise an ACL Rule.
|
Initialise an ACL Rule.
|
||||||
|
|
||||||
@@ -15,13 +15,13 @@ class ACLRule:
|
|||||||
:param _protocol: The rule protocol
|
:param _protocol: The rule protocol
|
||||||
:param _port: The rule port
|
:param _port: The rule port
|
||||||
"""
|
"""
|
||||||
self.permission = _permission
|
self.permission: str = _permission
|
||||||
self.source_ip = _source_ip
|
self.source_ip: str = _source_ip
|
||||||
self.dest_ip = _dest_ip
|
self.dest_ip: str = _dest_ip
|
||||||
self.protocol = _protocol
|
self.protocol: str = _protocol
|
||||||
self.port = _port
|
self.port: str = _port
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
"""
|
"""
|
||||||
Override the hash function.
|
Override the hash function.
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ class ACLRule:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_permission(self):
|
def get_permission(self) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the permission attribute.
|
Gets the permission attribute.
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class ACLRule:
|
|||||||
"""
|
"""
|
||||||
return self.permission
|
return self.permission
|
||||||
|
|
||||||
def get_source_ip(self):
|
def get_source_ip(self) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the source IP address attribute.
|
Gets the source IP address attribute.
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ class ACLRule:
|
|||||||
"""
|
"""
|
||||||
return self.source_ip
|
return self.source_ip
|
||||||
|
|
||||||
def get_dest_ip(self):
|
def get_dest_ip(self) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the desintation IP address attribute.
|
Gets the desintation IP address attribute.
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ class ACLRule:
|
|||||||
"""
|
"""
|
||||||
return self.dest_ip
|
return self.dest_ip
|
||||||
|
|
||||||
def get_protocol(self):
|
def get_protocol(self) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the protocol attribute.
|
Gets the protocol attribute.
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ class ACLRule:
|
|||||||
"""
|
"""
|
||||||
return self.protocol
|
return self.protocol
|
||||||
|
|
||||||
def get_port(self):
|
def get_port(self) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the port attribute.
|
Gets the port attribute.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user