Fix types according to mypy

This commit is contained in:
Marek Wolan
2023-07-14 16:38:55 +01:00
parent 31fedb945e
commit 2bb71623fa
8 changed files with 32 additions and 22 deletions

View File

@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
Credit to: https://stackoverflow.com/a/68154386 Credit to: https://stackoverflow.com/a/68154386
""" """
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> str: def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
super().__init__() super().__init__()
if "fmt" in kwargs: if "fmt" in kwargs:

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network.""" """A class that implements the access control list implementation for the network."""
from typing import Dict, Optional from typing import Dict
from primaite.acl.acl_rule import ACLRule from primaite.acl.acl_rule import ACLRule
@@ -76,9 +76,7 @@ class AccessControlList:
hash_value = hash(new_rule) hash_value = hash(new_rule)
self.acl[hash_value] = new_rule self.acl[hash_value] = new_rule
def remove_rule( def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> Optional[int]:
""" """
Removes a rule. Removes a rule.

View File

@@ -34,7 +34,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
else: else:
property_action = "NONE" property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]] new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
return new_action return new_action

View File

@@ -88,7 +88,7 @@ class TrainingConfig:
session_type: SessionType = SessionType.TRAIN session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run" "The type of PrimAITE session to run"
load_agent: str = False load_agent: bool = False
"Determine whether to load an agent from file" "Determine whether to load an agent from file"
agent_load_file: Optional[str] = None agent_load_file: Optional[str] = None
@@ -194,7 +194,7 @@ class TrainingConfig:
"The random number generator seed to be used while training the agent" "The random number generator seed to be used while training the agent"
@classmethod @classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
""" """
Create an instance of TrainingConfig from a dict. Create an instance of TrainingConfig from a dict.
@@ -211,9 +211,11 @@ class TrainingConfig:
"hard_coded_agent_view": HardCodedAgentView, "hard_coded_agent_view": HardCodedAgentView,
} }
# convert the string representation of enums into the actual enum values themselves?
for key, value in field_enum_map.items(): for key, value in field_enum_map.items():
if key in config_dict: if key in config_dict:
config_dict[key] = value[config_dict[key]] config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict) return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True) -> Dict: def to_dict(self, json_serializable: bool = True) -> Dict:
@@ -335,7 +337,7 @@ def convert_legacy_training_config_dict(
return config_dict return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> str: def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
""" """
Maps legacy training config keys to the new format keys. Maps legacy training config keys to the new format keys.

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements reward function.""" """Implements reward function."""
from typing import Dict, TYPE_CHECKING from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger from primaite import getLogger
from primaite.common.custom_typing import NodeUnion from primaite.common.custom_typing import NodeUnion
@@ -152,7 +152,10 @@ def score_node_operating_state(
def score_node_os_state( def score_node_os_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float: ) -> float:
""" """
Calculates score relating to the Software State of a node. Calculates score relating to the Software State of a node.
@@ -205,7 +208,7 @@ def score_node_os_state(
def score_node_service_state( def score_node_service_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float: ) -> float:
""" """
Calculates score relating to the service state(s) of a node. Calculates score relating to the service state(s) of a node.
@@ -279,7 +282,10 @@ def score_node_service_state(
def score_node_file_system( def score_node_file_system(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float: ) -> float:
""" """
Calculates score relating to the file system state of a node. Calculates score relating to the file system state of a node.

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links).""" """Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict, Union from typing import Dict
from networkx import MultiGraph, shortest_path from networkx import MultiGraph, shortest_path
@@ -10,7 +10,6 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
from primaite.links.link import Link from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER from primaite.pol.ier import IER
@@ -65,6 +64,8 @@ def apply_iers(
dest_node = nodes[dest_node_id] dest_node = nodes[dest_node_id]
# 1. Check the source node situation # 1. Check the source node situation
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
if source_node.node_type == NodeType.SWITCH: if source_node.node_type == NodeType.SWITCH:
# It's a switch # It's a switch
if ( if (
@@ -215,7 +216,7 @@ def apply_iers(
def apply_node_pol( def apply_node_pol(
nodes: Dict[str, NodeUnion], nodes: Dict[str, NodeUnion],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], node_pol: Dict[str, NodeStateInstructionGreen],
step: int, step: int,
) -> None: ) -> None:
""" """

View File

@@ -74,6 +74,9 @@ def apply_red_agent_iers(
pass pass
else: else:
# It's not a switch or an actuator (so active node) # It's not a switch or an actuator (so active node)
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
# to change according to duck typing.
if source_node.hardware_state == HardwareState.ON: if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol): if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state # Red agents IERs can only be valid if the source service is in a compromised state

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Transaction class.""" """The Transaction class."""
from datetime import datetime from datetime import datetime
from typing import List, Tuple, TYPE_CHECKING, Union from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from primaite.common.enums import AgentIdentifier from primaite.common.enums import AgentIdentifier
@@ -31,15 +31,15 @@ class Transaction(object):
"The step number" "The step number"
self.obs_space: "spaces.Space" = None self.obs_space: "spaces.Space" = None
"The observation space (pre)" "The observation space (pre)"
self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space before any actions are taken" "The observation space before any actions are taken"
self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space after any actions are taken" "The observation space after any actions are taken"
self.reward: float = None self.reward: Optional[float] = None
"The reward value" "The reward value"
self.action_space: int = None self.action_space: Optional[int] = None
"The action space invoked by the agent" "The action space invoked by the agent"
self.obs_space_description: List[str] = None self.obs_space_description: Optional[List[str]] = None
"The env observation space description" "The env observation space description"
def as_csv_data(self) -> Tuple[List, List]: def as_csv_data(self) -> Tuple[List, List]: