#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -10,9 +10,7 @@ class AccessControlList:
def __init__(self):
"""Init."""
self.acl: Dict[
str, AccessControlList
] = {} # A dictionary of ACL Rules
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
@@ -27,29 +25,16 @@ class AccessControlList:
True if match; False otherwise.
"""
if (
(
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == "ANY"
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == "ANY"
)
or (
_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY"
)
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
):
return True
else:
return False
def is_blocked(
self, _source_ip_address, _dest_ip_address, _protocol, _port
):
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
"""
Checks for rules that block a protocol / port.
@@ -63,15 +48,9 @@ class AccessControlList:
Indicates block if all conditions are satisfied.
"""
for rule_key, rule_value in self.acl.items():
if self.check_address_match(
rule_value, _source_ip_address, _dest_ip_address
):
if (
rule_value.get_protocol() == _protocol
or rule_value.get_protocol() == "ANY"
) and (
str(rule_value.get_port()) == str(_port)
or rule_value.get_port() == "ANY"
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule_value.get_permission() == "DENY":
@@ -93,9 +72,7 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
new_rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
@@ -110,9 +87,7 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
# There will not always be something 'popable' since the agent will be trying random things
try:
@@ -124,9 +99,7 @@ class AccessControlList:
"""Removes all rules."""
self.acl.clear()
def get_dictionary_hash(
self, _permission, _source_ip, _dest_ip, _protocol, _port
):
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""
Produces a hash value for a rule.
@@ -140,8 +113,6 @@ class AccessControlList:
Returns:
Hash value based on rule parameters.
"""
rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
return hash_value