Fix types according to mypy

This commit is contained in:
Marek Wolan
2023-07-14 16:38:55 +01:00
parent 31fedb945e
commit 2bb71623fa
8 changed files with 32 additions and 22 deletions

View File

@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
Credit to: https://stackoverflow.com/a/68154386
"""
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> str:
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
super().__init__()
if "fmt" in kwargs:

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict, Optional
from typing import Dict
from primaite.acl.acl_rule import ACLRule
@@ -76,9 +76,7 @@ class AccessControlList:
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
def remove_rule(
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> Optional[int]:
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
"""
Removes a rule.

View File

@@ -34,7 +34,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
return new_action

View File

@@ -88,7 +88,7 @@ class TrainingConfig:
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: str = False
load_agent: bool = False
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
@@ -194,7 +194,7 @@ class TrainingConfig:
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
@@ -211,9 +211,11 @@ class TrainingConfig:
"hard_coded_agent_view": HardCodedAgentView,
}
# convert the string representation of enums into the actual enum values themselves?
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True) -> Dict:
@@ -335,7 +337,7 @@ def convert_legacy_training_config_dict(
return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> str:
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
"""
Maps legacy training config keys to the new format keys.

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict, TYPE_CHECKING
from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.common.custom_typing import NodeUnion
@@ -152,7 +152,10 @@ def score_node_operating_state(
def score_node_os_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the Software State of a node.
@@ -205,7 +208,7 @@ def score_node_os_state(
def score_node_service_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the service state(s) of a node.
@@ -279,7 +282,10 @@ def score_node_service_state(
def score_node_file_system(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the file system state of a node.

View File

@@ -1,6 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict, Union
from typing import Dict
from networkx import MultiGraph, shortest_path
@@ -10,7 +10,6 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
@@ -65,6 +64,8 @@ def apply_iers(
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if (
@@ -215,7 +216,7 @@ def apply_iers(
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
node_pol: Dict[str, NodeStateInstructionGreen],
step: int,
) -> None:
"""

View File

@@ -74,6 +74,9 @@ def apply_red_agent_iers(
pass
else:
# It's not a switch or an actuator (so active node)
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
# to change according to duck typing.
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Tuple, TYPE_CHECKING, Union
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from primaite.common.enums import AgentIdentifier
@@ -31,15 +31,15 @@ class Transaction(object):
"The step number"
self.obs_space: "spaces.Space" = None
"The observation space (pre)"
self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space before any actions are taken"
self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
"The observation space after any actions are taken"
self.reward: float = None
self.reward: Optional[float] = None
"The reward value"
self.action_space: int = None
self.action_space: Optional[int] = None
"The action space invoked by the agent"
self.obs_space_description: List[str] = None
self.obs_space_description: Optional[List[str]] = None
"The env observation space description"
def as_csv_data(self) -> Tuple[List, List]: