#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
if (
self._training_config.hard_coded_agent_view
== HardCodedAgentView.BASIC
):
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
@@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(
source_node_address, dest_node_address, protocol, port
):
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
@@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
@@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
dest_node_address = dest_node_id
if protocol != "ANY":
protocol = services_list[
protocol - 1
] # -1 as dont have to account for ANY in list of services
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_allow_acl_rules(
@@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(
get_node_of_ip(action_source_ip, self._env.nodes)
)
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(
get_node_of_ip(
action_destination_ip, self._env.nodes
)
)
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name)
+ 1
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = (
@@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(
self._env.green_iers, self._env.acl, self._env.nodes
)
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(
ier, self._env.acl, self._env.nodes
)
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(
service_name_to_check
)
service_id_to_check = self._env.services_list.index(service_name_to_check)
# Service state of the the source node in the ier
service_state = s[service_id_to_check][node_id_to_check - 1]
@@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if len(r_obs) == 4: # only 1 service
s = [*s]
number_of_nodes = len(
[i for i in o if i != "NONE"]
) # number of nodes (not links)
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
for service_num, service_states in enumerate(s):
comprimised_states = [
n for n, i in enumerate(service_states) if i == "COMPROMISED"
]
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
if len(comprimised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = (
np.random.choice(comprimised_states) + 1
) # +1 as 0 would be any
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(
list(range(1, number_of_nodes + 1)) + ["ANY"]
)
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = (
int(action_destination_ip)
if action_destination_ip != "ANY"
else action_destination_ip
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block