From 05ebd15053c9da300000529cdf9ac1aef80e7e73 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 26 May 2023 09:43:37 +0100 Subject: [PATCH] #1355 - Renamed the NodeType custom type in custom_typing.py as it clased with the NodeType enum in enums.py --- src/primaite/common/custom_typing.py | 2 +- src/primaite/environment/primaite_env.py | 8 ++++---- src/primaite/pol/green_pol.py | 6 +++--- src/primaite/pol/red_agent_pol.py | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/primaite/common/custom_typing.py b/src/primaite/common/custom_typing.py index 03ce6a3b..37b10efe 100644 --- a/src/primaite/common/custom_typing.py +++ b/src/primaite/common/custom_typing.py @@ -4,5 +4,5 @@ from primaite.nodes.active_node import ActiveNode from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode -NodeType: Type = Union[ActiveNode, PassiveNode, ServiceNode] +NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode] """A Union of ActiveNode, PassiveNode, and ServiceNode.""" diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5d68009d..99c7c09f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -15,7 +15,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.common.custom_typing import NodeType +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, FileSystemState, @@ -80,10 +80,10 @@ class Primaite(Env): self.agent_identifier = self.config_values.agent_identifier # Create a dictionary to hold all the nodes - self.nodes: Dict[str, NodeType] = {} + self.nodes: Dict[str, NodeUnion] = {} # Create a dictionary to hold a reference set of nodes - self.nodes_reference: Dict[str, NodeType] = {} + self.nodes_reference: Dict[str, NodeUnion] = {} # Create a dictionary to hold all the links self.links: Dict[str, Link] = {} @@ -1128,7 +1128,7 @@ class Primaite(Env): node_class = item["node_class"] node_hardware_state: HardwareState = HardwareState[item["hardware_state"]] - node: NodeType = self.nodes[node_id] + node: NodeUnion = self.nodes[node_id] node_ref = self.nodes_reference[node_id] # Reset the hardware state (common for all node types) diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index e8dbd2c2..1d05dc3f 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -5,7 +5,7 @@ from typing import Dict, Union from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList -from primaite.common.custom_typing import NodeType +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode @@ -19,7 +19,7 @@ _VERBOSE = False def apply_iers( network: MultiGraph, - nodes: Dict[str, NodeType], + nodes: Dict[str, NodeUnion], links: Dict[str, Link], iers: Dict[str, IER], acl: AccessControlList, @@ -230,7 +230,7 @@ def apply_iers( def apply_node_pol( - nodes: Dict[str, NodeType], + nodes: Dict[str, NodeUnion], node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], step: int, ): diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index 6a060373..b23992e7 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -5,7 +5,7 @@ from typing import Dict from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList -from primaite.common.custom_typing import NodeType +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( HardwareState, NodePOLInitiator, @@ -24,7 +24,7 @@ _VERBOSE = False def apply_red_agent_iers( network: MultiGraph, - nodes: Dict[str, NodeType], + nodes: Dict[str, NodeUnion], links: Dict[str, Link], iers: Dict[str, IER], acl: AccessControlList, @@ -221,7 +221,7 @@ def apply_red_agent_iers( def apply_red_agent_node_pol( - nodes: NodeType, + nodes: Dict[str, NodeUnion], iers: Dict[str, IER], node_pol: Dict[str, NodeStateInstructionRed], step: int, @@ -256,7 +256,7 @@ def apply_red_agent_node_pol( if step >= start_step and step <= stop_step: # continue -------------------------- - target_node: NodeType = nodes[target_node_id] + target_node: NodeUnion = nodes[target_node_id] # Based the action taken on the initiator type if initiator == NodePOLInitiator.DIRECT: