diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 950ceb3d..dacd5c12 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -72,7 +72,7 @@ class _LevelFormatter(Formatter): Credit to: https://stackoverflow.com/a/68154386 """ - def __init__(self, formats: Dict[int, str], **kwargs: Any) -> str: + def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None: super().__init__() if "fmt" in kwargs: diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index f7e65bd4..d4d843e3 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -1,6 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """A class that implements the access control list implementation for the network.""" -from typing import Dict, Optional +from typing import Dict from primaite.acl.acl_rule import ACLRule @@ -76,9 +76,7 @@ class AccessControlList: hash_value = hash(new_rule) self.acl[hash_value] = new_rule - def remove_rule( - self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str - ) -> Optional[int]: + def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None: """ Removes a rule. diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 2e6b3f0c..353978f1 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -34,7 +34,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: else: property_action = "NONE" - new_action = [action[0], action_node_property, property_action, action[3]] + new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]] return new_action diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 08da043c..628e2818 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -88,7 +88,7 @@ class TrainingConfig: session_type: SessionType = SessionType.TRAIN "The type of PrimAITE session to run" - load_agent: str = False + load_agent: bool = False "Determine whether to load an agent from file" agent_load_file: Optional[str] = None @@ -194,7 +194,7 @@ class TrainingConfig: "The random number generator seed to be used while training the agent" @classmethod - def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: + def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig: """ Create an instance of TrainingConfig from a dict. @@ -211,9 +211,11 @@ class TrainingConfig: "hard_coded_agent_view": HardCodedAgentView, } + # convert the string representation of enums into the actual enum values themselves? for key, value in field_enum_map.items(): if key in config_dict: config_dict[key] = value[config_dict[key]] + return TrainingConfig(**config_dict) def to_dict(self, json_serializable: bool = True) -> Dict: @@ -335,7 +337,7 @@ def convert_legacy_training_config_dict( return config_dict -def _get_new_key_from_legacy(legacy_key: str) -> str: +def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]: """ Maps legacy training config keys to the new format keys. diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index c9acd921..a0efac4d 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -1,6 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Implements reward function.""" -from typing import Dict, TYPE_CHECKING +from typing import Dict, TYPE_CHECKING, Union from primaite import getLogger from primaite.common.custom_typing import NodeUnion @@ -152,7 +152,10 @@ def score_node_operating_state( def score_node_os_state( - final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" + final_node: Union[ActiveNode, ServiceNode], + initial_node: Union[ActiveNode, ServiceNode], + reference_node: Union[ActiveNode, ServiceNode], + config_values: "TrainingConfig", ) -> float: """ Calculates score relating to the Software State of a node. @@ -205,7 +208,7 @@ def score_node_os_state( def score_node_service_state( - final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" + final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig" ) -> float: """ Calculates score relating to the service state(s) of a node. @@ -279,7 +282,10 @@ def score_node_service_state( def score_node_file_system( - final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" + final_node: Union[ActiveNode, ServiceNode], + initial_node: Union[ActiveNode, ServiceNode], + reference_node: Union[ActiveNode, ServiceNode], + config_values: "TrainingConfig", ) -> float: """ Calculates score relating to the file system state of a node. diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 89bda871..7df87590 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -1,6 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Implements Pattern of Life on the network (nodes and links).""" -from typing import Dict, Union +from typing import Dict from networkx import MultiGraph, shortest_path @@ -10,7 +10,6 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen -from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER @@ -65,6 +64,8 @@ def apply_iers( dest_node = nodes[dest_node_id] # 1. Check the source node situation + # TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch + # doesn't mean it has a software state? It could be a PassiveNode or ActiveNode if source_node.node_type == NodeType.SWITCH: # It's a switch if ( @@ -215,7 +216,7 @@ def apply_iers( def apply_node_pol( nodes: Dict[str, NodeUnion], - node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], + node_pol: Dict[str, NodeStateInstructionGreen], step: int, ) -> None: """ diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index 09c25fa1..c9f75850 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -74,6 +74,9 @@ def apply_red_agent_iers( pass else: # It's not a switch or an actuator (so active node) + # TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it + # could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs + # to change according to duck typing. if source_node.hardware_state == HardwareState.ON: if source_node.has_service(protocol): # Red agents IERs can only be valid if the source service is in a compromised state diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 67f67e43..09ec2cec 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """The Transaction class.""" from datetime import datetime -from typing import List, Tuple, TYPE_CHECKING, Union +from typing import List, Optional, Tuple, TYPE_CHECKING, Union from primaite.common.enums import AgentIdentifier @@ -31,15 +31,15 @@ class Transaction(object): "The step number" self.obs_space: "spaces.Space" = None "The observation space (pre)" - self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None + self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None "The observation space before any actions are taken" - self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None + self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None "The observation space after any actions are taken" - self.reward: float = None + self.reward: Optional[float] = None "The reward value" - self.action_space: int = None + self.action_space: Optional[int] = None "The action space invoked by the agent" - self.obs_space_description: List[str] = None + self.obs_space_description: Optional[List[str]] = None "The env observation space description" def as_csv_data(self) -> Tuple[List, List]: