Fix types according to mypy
This commit is contained in:
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
|
||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> str:
|
||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
|
||||
@@ -76,9 +76,7 @@ class AccessControlList:
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
def remove_rule(
|
||||
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> Optional[int]:
|
||||
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
||||
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ class TrainingConfig:
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: str = False
|
||||
load_agent: bool = False
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
@@ -194,7 +194,7 @@ class TrainingConfig:
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
@@ -211,9 +211,11 @@ class TrainingConfig:
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
@@ -335,7 +337,7 @@ def convert_legacy_training_config_dict(
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Implements reward function."""
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
@@ -152,7 +152,10 @@ def score_node_operating_state(
|
||||
|
||||
|
||||
def score_node_os_state(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
@@ -205,7 +208,7 @@ def score_node_os_state(
|
||||
|
||||
|
||||
def score_node_service_state(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
@@ -279,7 +282,10 @@ def score_node_service_state(
|
||||
|
||||
|
||||
def score_node_file_system(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict, Union
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
@@ -10,7 +10,6 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
@@ -65,6 +64,8 @@ def apply_iers(
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
|
||||
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
@@ -215,7 +216,7 @@ def apply_iers(
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
node_pol: Dict[str, NodeStateInstructionGreen],
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@@ -74,6 +74,9 @@ def apply_red_agent_iers(
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
|
||||
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
|
||||
# to change according to duck typing.
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, TYPE_CHECKING, Union
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
@@ -31,15 +31,15 @@ class Transaction(object):
|
||||
"The step number"
|
||||
self.obs_space: "spaces.Space" = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None
|
||||
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None
|
||||
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward: float = None
|
||||
self.reward: Optional[float] = None
|
||||
"The reward value"
|
||||
self.action_space: int = None
|
||||
self.action_space: Optional[int] = None
|
||||
"The action space invoked by the agent"
|
||||
self.obs_space_description: List[str] = None
|
||||
self.obs_space_description: Optional[List[str]] = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
|
||||
Reference in New Issue
Block a user