diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 3b0e9234..c155deed 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -116,3 +116,27 @@ class AccessControlList: rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) return hash_value + + def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port): + """Get all ACL rules that relate to the given arguments + + :param _source_ip_address: the source IP address to check + :param _dest_ip_address: the destination IP address to check + :param _protocol: the protocol to check + :param _port: the port to check + :return: Dictionary of all ACL rules that relate to the given arguments + :rtype: Dict[str, ACLRule] + """ + relevant_rules = {} + + for rule_key, rule_value in self.acl.items(): + if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): + if ( + rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY" + ) and ( + str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY" + ): + # There's a matching rule. + relevant_rules[rule_key] = rule_value + + return relevant_rules