#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -6,17 +6,10 @@ from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import (
|
||||
NodeStateInstructionGreen,
|
||||
)
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
@@ -93,9 +86,7 @@ def apply_iers(
|
||||
and source_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if source_node.has_service(protocol):
|
||||
if source_node.service_running(
|
||||
protocol
|
||||
) and not source_node.service_is_overwhelmed(protocol):
|
||||
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -110,10 +101,7 @@ def apply_iers(
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
dest_node.hardware_state == HardwareState.ON
|
||||
and dest_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
@@ -123,14 +111,9 @@ def apply_iers(
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if (
|
||||
dest_node.hardware_state == HardwareState.ON
|
||||
and dest_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
if dest_node.has_service(protocol):
|
||||
if dest_node.service_running(
|
||||
protocol
|
||||
) and not dest_node.service_is_overwhelmed(protocol):
|
||||
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
|
||||
dest_valid = True
|
||||
else:
|
||||
dest_valid = False
|
||||
@@ -143,9 +126,7 @@ def apply_iers(
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.ip_address, dest_node.ip_address, protocol, port
|
||||
)
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
@@ -176,10 +157,7 @@ def apply_iers(
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
for node in path_node_list:
|
||||
if (
|
||||
node.hardware_state != HardwareState.ON
|
||||
or node.software_state == SoftwareState.PATCHING
|
||||
):
|
||||
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
@@ -191,15 +169,11 @@ def apply_iers(
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (
|
||||
link.get_current_load() + load
|
||||
) > link.get_bandwidth():
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
@@ -226,9 +200,7 @@ def apply_iers(
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Path not valid - one or more nodes not operational"
|
||||
)
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
@@ -243,9 +215,7 @@ def apply_iers(
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[
|
||||
any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]
|
||||
],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
@@ -277,22 +247,16 @@ def apply_node_pol(
|
||||
elif node_pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ActiveNode) or isinstance(
|
||||
node, ServiceNode
|
||||
):
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_software_state_if_not_compromised(state)
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ServiceNode):
|
||||
node.set_service_state_if_not_compromised(
|
||||
service_name, state
|
||||
)
|
||||
node.set_service_state_if_not_compromised(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(node, ActiveNode) or isinstance(
|
||||
node, ServiceNode
|
||||
):
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_file_system_state_if_not_compromised(state)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
|
||||
@@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
@@ -83,10 +77,7 @@ def apply_red_agent_iers(
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
if (
|
||||
source_node.get_service_state(protocol)
|
||||
== SoftwareState.COMPROMISED
|
||||
):
|
||||
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -124,9 +115,7 @@ def apply_red_agent_iers(
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.ip_address, dest_node.ip_address, protocol, port
|
||||
)
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
@@ -170,15 +159,11 @@ def apply_red_agent_iers(
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (
|
||||
link.get_current_load() + load
|
||||
) > link.get_bandwidth():
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
@@ -203,23 +188,16 @@ def apply_red_agent_iers(
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Red IER was allowed to run in step "
|
||||
+ str(step)
|
||||
)
|
||||
print("Red IER was allowed to run in step " + str(step))
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Path not valid - one or more nodes not operational"
|
||||
)
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Red IER was NOT allowed to run in step " + str(step)
|
||||
)
|
||||
print("Red IER was NOT allowed to run in step " + str(step))
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
@@ -258,9 +236,7 @@ def apply_red_agent_node_pol(
|
||||
state = node_instruction.get_state()
|
||||
source_node_id = node_instruction.get_source_node_id()
|
||||
source_node_service_name = node_instruction.get_source_node_service()
|
||||
source_node_service_state_value = (
|
||||
node_instruction.get_source_node_service_state()
|
||||
)
|
||||
source_node_service_state_value = node_instruction.get_source_node_service_state()
|
||||
|
||||
passed_checks = False
|
||||
|
||||
@@ -274,9 +250,7 @@ def apply_red_agent_node_pol(
|
||||
passed_checks = True
|
||||
elif initiator == NodePOLInitiator.IER:
|
||||
# Need to check there is a red IER incoming
|
||||
passed_checks = is_red_ier_incoming(
|
||||
target_node, iers, pol_type
|
||||
)
|
||||
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
|
||||
elif initiator == NodePOLInitiator.SERVICE:
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
@@ -304,9 +278,7 @@ def apply_red_agent_node_pol(
|
||||
target_node.hardware_state = state
|
||||
elif pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.software_state = state
|
||||
elif pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
@@ -314,15 +286,11 @@ def apply_red_agent_node_pol(
|
||||
target_node.set_service_state(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Node Red Agent PoL not allowed - did not pass checks"
|
||||
)
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
@@ -337,10 +305,7 @@ def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
node_id = node.node_id
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
if (
|
||||
ier_value.get_is_running()
|
||||
and ier_value.get_dest_node_id() == node_id
|
||||
):
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if (
|
||||
node_pol_type == NodePOLType.OPERATING
|
||||
or node_pol_type == NodePOLType.OS
|
||||
|
||||
Reference in New Issue
Block a user