#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
if (
|
||||
self._training_config.hard_coded_agent_view
|
||||
== HardCodedAgentView.BASIC
|
||||
):
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
@@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
):
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
@@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
@@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
dest_node_address = dest_node_id
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[
|
||||
protocol - 1
|
||||
] # -1 as dont have to account for ANY in list of services
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_allow_acl_rules(
|
||||
@@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = rule.get_source_ip()
|
||||
action_source_id = int(
|
||||
get_node_of_ip(action_source_ip, self._env.nodes)
|
||||
)
|
||||
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
|
||||
action_destination_ip = rule.get_dest_ip()
|
||||
action_destination_id = int(
|
||||
get_node_of_ip(
|
||||
action_destination_ip, self._env.nodes
|
||||
)
|
||||
)
|
||||
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
|
||||
action_protocol_name = rule.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name)
|
||||
+ 1
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = rule.get_port()
|
||||
action_port = (
|
||||
@@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if not found_action:
|
||||
# Which Green IERS are blocked
|
||||
blocked_green_iers = self.get_blocked_green_iers(
|
||||
self._env.green_iers, self._env.acl, self._env.nodes
|
||||
)
|
||||
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
|
||||
for ier_key, ier in blocked_green_iers.items():
|
||||
# Which ALLOW rules are allowing this IER (none)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(
|
||||
ier, self._env.acl, self._env.nodes
|
||||
)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
|
||||
|
||||
# If there are no blocking rules, it may be being blocked by default
|
||||
# If there is already an allow rule
|
||||
node_id_to_check = int(ier.get_source_node_id())
|
||||
service_name_to_check = ier.get_protocol()
|
||||
service_id_to_check = self._env.services_list.index(
|
||||
service_name_to_check
|
||||
)
|
||||
service_id_to_check = self._env.services_list.index(service_name_to_check)
|
||||
|
||||
# Service state of the the source node in the ier
|
||||
service_state = s[service_id_to_check][node_id_to_check - 1]
|
||||
@@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
number_of_nodes = len(
|
||||
[i for i in o if i != "NONE"]
|
||||
) # number of nodes (not links)
|
||||
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
|
||||
for service_num, service_states in enumerate(s):
|
||||
comprimised_states = [
|
||||
n for n, i in enumerate(service_states) if i == "COMPROMISED"
|
||||
]
|
||||
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
|
||||
if len(comprimised_states) == 0:
|
||||
# No states are COMPROMISED, try the next service
|
||||
continue
|
||||
|
||||
compromised_node = (
|
||||
np.random.choice(comprimised_states) + 1
|
||||
) # +1 as 0 would be any
|
||||
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = compromised_node
|
||||
# Randomly select a destination ID to block
|
||||
action_destination_ip = np.random.choice(
|
||||
list(range(1, number_of_nodes + 1)) + ["ANY"]
|
||||
)
|
||||
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
|
||||
action_destination_ip = (
|
||||
int(action_destination_ip)
|
||||
if action_destination_ip != "ANY"
|
||||
else action_destination_ip
|
||||
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
|
||||
)
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
# Randomly select a port to block
|
||||
|
||||
Reference in New Issue
Block a user