#917 - Fixed the RLlib integration

- Dropped support for overriding num_episodes and num_steps at the agent level. It is not needed and would add complexity when overriding values and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -6,17 +6,10 @@ from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
NodePOLType,
NodeType,
SoftwareState,
)
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import (
NodeStateInstructionGreen,
)
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
@@ -93,9 +86,7 @@ def apply_iers(
and source_node.software_state != SoftwareState.PATCHING
):
if source_node.has_service(protocol):
if source_node.service_running(
protocol
) and not source_node.service_is_overwhelmed(protocol):
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
source_valid = True
else:
source_valid = False
@@ -110,10 +101,7 @@ def apply_iers(
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if (
dest_node.hardware_state == HardwareState.ON
and dest_node.software_state != SoftwareState.PATCHING
):
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
dest_valid = True
else:
# IER no longer valid
@@ -123,14 +111,9 @@ def apply_iers(
pass
else:
# It's not a switch or an actuator (so active node)
if (
dest_node.hardware_state == HardwareState.ON
and dest_node.software_state != SoftwareState.PATCHING
):
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
if dest_node.has_service(protocol):
if dest_node.service_running(
protocol
) and not dest_node.service_is_overwhelmed(protocol):
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
dest_valid = True
else:
dest_valid = False
@@ -143,9 +126,7 @@ def apply_iers(
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(
source_node.ip_address, dest_node.ip_address, protocol, port
)
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
@@ -176,10 +157,7 @@ def apply_iers(
# We might have a switch in the path, so check all nodes are operational
for node in path_node_list:
if (
node.hardware_state != HardwareState.ON
or node.software_state == SoftwareState.PATCHING
):
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
path_valid = False
if path_valid:
@@ -191,15 +169,11 @@ def apply_iers(
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
)
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (
link.get_current_load() + load
) > link.get_bandwidth():
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
@@ -226,9 +200,7 @@ def apply_iers(
else:
# One of the nodes is not operational
if _VERBOSE:
print(
"Path not valid - one or more nodes not operational"
)
print("Path not valid - one or more nodes not operational")
pass
else:
@@ -243,9 +215,7 @@ def apply_iers(
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[
any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]
],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
step: int,
):
"""
@@ -277,22 +247,16 @@ def apply_node_pol(
elif node_pol_type == NodePOLType.OS:
# Change OS state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ActiveNode) or isinstance(
node, ServiceNode
):
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_software_state_if_not_compromised(state)
elif node_pol_type == NodePOLType.SERVICE:
# Change a service state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ServiceNode):
node.set_service_state_if_not_compromised(
service_name, state
)
node.set_service_state_if_not_compromised(service_name, state)
else:
# Change the file system status
if isinstance(node, ActiveNode) or isinstance(
node, ServiceNode
):
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_file_system_state_if_not_compromised(state)
else:
# PoL is not valid in this time step

View File

@@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
NodePOLInitiator,
NodePOLType,
NodeType,
SoftwareState,
)
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
@@ -83,10 +77,7 @@ def apply_red_agent_iers(
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state
if (
source_node.get_service_state(protocol)
== SoftwareState.COMPROMISED
):
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
source_valid = True
else:
source_valid = False
@@ -124,9 +115,7 @@ def apply_red_agent_iers(
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(
source_node.ip_address, dest_node.ip_address, protocol, port
)
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
@@ -170,15 +159,11 @@ def apply_red_agent_iers(
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
)
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (
link.get_current_load() + load
) > link.get_bandwidth():
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
@@ -203,23 +188,16 @@ def apply_red_agent_iers(
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
if _VERBOSE:
print(
"Red IER was allowed to run in step "
+ str(step)
)
print("Red IER was allowed to run in step " + str(step))
else:
# One of the nodes is not operational
if _VERBOSE:
print(
"Path not valid - one or more nodes not operational"
)
print("Path not valid - one or more nodes not operational")
pass
else:
if _VERBOSE:
print(
"Red IER was NOT allowed to run in step " + str(step)
)
print("Red IER was NOT allowed to run in step " + str(step))
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
@@ -258,9 +236,7 @@ def apply_red_agent_node_pol(
state = node_instruction.get_state()
source_node_id = node_instruction.get_source_node_id()
source_node_service_name = node_instruction.get_source_node_service()
source_node_service_state_value = (
node_instruction.get_source_node_service_state()
)
source_node_service_state_value = node_instruction.get_source_node_service_state()
passed_checks = False
@@ -274,9 +250,7 @@ def apply_red_agent_node_pol(
passed_checks = True
elif initiator == NodePOLInitiator.IER:
# Need to check there is a red IER incoming
passed_checks = is_red_ier_incoming(
target_node, iers, pol_type
)
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
elif initiator == NodePOLInitiator.SERVICE:
# Need to check the condition of a service on another node
source_node = nodes[source_node_id]
@@ -304,9 +278,7 @@ def apply_red_agent_node_pol(
target_node.hardware_state = state
elif pol_type == NodePOLType.OS:
# Change OS state
if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.software_state = state
elif pol_type == NodePOLType.SERVICE:
# Change a service state
@@ -314,15 +286,11 @@ def apply_red_agent_node_pol(
target_node.set_service_state(service_name, state)
else:
# Change the file system status
if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
if _VERBOSE:
print(
"Node Red Agent PoL not allowed - did not pass checks"
)
print("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass
@@ -337,10 +305,7 @@ def is_red_ier_incoming(node, iers, node_pol_type):
node_id = node.node_id
for ier_key, ier_value in iers.items():
if (
ier_value.get_is_running()
and ier_value.get_dest_node_id() == node_id
):
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
if (
node_pol_type == NodePOLType.OPERATING
or node_pol_type == NodePOLType.OS