Fix types according to mypy
This commit is contained in:
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
|
|||||||
Credit to: https://stackoverflow.com/a/68154386
|
Credit to: https://stackoverflow.com/a/68154386
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> str:
|
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if "fmt" in kwargs:
|
if "fmt" in kwargs:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""A class that implements the access control list implementation for the network."""
|
"""A class that implements the access control list implementation for the network."""
|
||||||
from typing import Dict, Optional
|
from typing import Dict
|
||||||
|
|
||||||
from primaite.acl.acl_rule import ACLRule
|
from primaite.acl.acl_rule import ACLRule
|
||||||
|
|
||||||
@@ -76,9 +76,7 @@ class AccessControlList:
|
|||||||
hash_value = hash(new_rule)
|
hash_value = hash(new_rule)
|
||||||
self.acl[hash_value] = new_rule
|
self.acl[hash_value] = new_rule
|
||||||
|
|
||||||
def remove_rule(
|
def remove_rule(self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None:
|
||||||
self, _permission: str, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""
|
"""
|
||||||
Removes a rule.
|
Removes a rule.
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
|||||||
else:
|
else:
|
||||||
property_action = "NONE"
|
property_action = "NONE"
|
||||||
|
|
||||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
|
||||||
return new_action
|
return new_action
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class TrainingConfig:
|
|||||||
session_type: SessionType = SessionType.TRAIN
|
session_type: SessionType = SessionType.TRAIN
|
||||||
"The type of PrimAITE session to run"
|
"The type of PrimAITE session to run"
|
||||||
|
|
||||||
load_agent: str = False
|
load_agent: bool = False
|
||||||
"Determine whether to load an agent from file"
|
"Determine whether to load an agent from file"
|
||||||
|
|
||||||
agent_load_file: Optional[str] = None
|
agent_load_file: Optional[str] = None
|
||||||
@@ -194,7 +194,7 @@ class TrainingConfig:
|
|||||||
"The random number generator seed to be used while training the agent"
|
"The random number generator seed to be used while training the agent"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||||
"""
|
"""
|
||||||
Create an instance of TrainingConfig from a dict.
|
Create an instance of TrainingConfig from a dict.
|
||||||
|
|
||||||
@@ -211,9 +211,11 @@ class TrainingConfig:
|
|||||||
"hard_coded_agent_view": HardCodedAgentView,
|
"hard_coded_agent_view": HardCodedAgentView,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# convert the string representation of enums into the actual enum values themselves?
|
||||||
for key, value in field_enum_map.items():
|
for key, value in field_enum_map.items():
|
||||||
if key in config_dict:
|
if key in config_dict:
|
||||||
config_dict[key] = value[config_dict[key]]
|
config_dict[key] = value[config_dict[key]]
|
||||||
|
|
||||||
return TrainingConfig(**config_dict)
|
return TrainingConfig(**config_dict)
|
||||||
|
|
||||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||||
@@ -335,7 +337,7 @@ def convert_legacy_training_config_dict(
|
|||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
def _get_new_key_from_legacy(legacy_key: str) -> str:
|
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Maps legacy training config keys to the new format keys.
|
Maps legacy training config keys to the new format keys.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""Implements reward function."""
|
"""Implements reward function."""
|
||||||
from typing import Dict, TYPE_CHECKING
|
from typing import Dict, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from primaite import getLogger
|
from primaite import getLogger
|
||||||
from primaite.common.custom_typing import NodeUnion
|
from primaite.common.custom_typing import NodeUnion
|
||||||
@@ -152,7 +152,10 @@ def score_node_operating_state(
|
|||||||
|
|
||||||
|
|
||||||
def score_node_os_state(
|
def score_node_os_state(
|
||||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
final_node: Union[ActiveNode, ServiceNode],
|
||||||
|
initial_node: Union[ActiveNode, ServiceNode],
|
||||||
|
reference_node: Union[ActiveNode, ServiceNode],
|
||||||
|
config_values: "TrainingConfig",
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Calculates score relating to the Software State of a node.
|
Calculates score relating to the Software State of a node.
|
||||||
@@ -205,7 +208,7 @@ def score_node_os_state(
|
|||||||
|
|
||||||
|
|
||||||
def score_node_service_state(
|
def score_node_service_state(
|
||||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Calculates score relating to the service state(s) of a node.
|
Calculates score relating to the service state(s) of a node.
|
||||||
@@ -279,7 +282,10 @@ def score_node_service_state(
|
|||||||
|
|
||||||
|
|
||||||
def score_node_file_system(
|
def score_node_file_system(
|
||||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
final_node: Union[ActiveNode, ServiceNode],
|
||||||
|
initial_node: Union[ActiveNode, ServiceNode],
|
||||||
|
reference_node: Union[ActiveNode, ServiceNode],
|
||||||
|
config_values: "TrainingConfig",
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Calculates score relating to the file system state of a node.
|
Calculates score relating to the file system state of a node.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||||
from typing import Dict, Union
|
from typing import Dict
|
||||||
|
|
||||||
from networkx import MultiGraph, shortest_path
|
from networkx import MultiGraph, shortest_path
|
||||||
|
|
||||||
@@ -10,7 +10,6 @@ from primaite.common.enums import HardwareState, NodePOLType, NodeType, Software
|
|||||||
from primaite.links.link import Link
|
from primaite.links.link import Link
|
||||||
from primaite.nodes.active_node import ActiveNode
|
from primaite.nodes.active_node import ActiveNode
|
||||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
|
||||||
from primaite.nodes.service_node import ServiceNode
|
from primaite.nodes.service_node import ServiceNode
|
||||||
from primaite.pol.ier import IER
|
from primaite.pol.ier import IER
|
||||||
|
|
||||||
@@ -65,6 +64,8 @@ def apply_iers(
|
|||||||
dest_node = nodes[dest_node_id]
|
dest_node = nodes[dest_node_id]
|
||||||
|
|
||||||
# 1. Check the source node situation
|
# 1. Check the source node situation
|
||||||
|
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
|
||||||
|
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
|
||||||
if source_node.node_type == NodeType.SWITCH:
|
if source_node.node_type == NodeType.SWITCH:
|
||||||
# It's a switch
|
# It's a switch
|
||||||
if (
|
if (
|
||||||
@@ -215,7 +216,7 @@ def apply_iers(
|
|||||||
|
|
||||||
def apply_node_pol(
|
def apply_node_pol(
|
||||||
nodes: Dict[str, NodeUnion],
|
nodes: Dict[str, NodeUnion],
|
||||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
node_pol: Dict[str, NodeStateInstructionGreen],
|
||||||
step: int,
|
step: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -74,6 +74,9 @@ def apply_red_agent_iers(
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# It's not a switch or an actuator (so active node)
|
# It's not a switch or an actuator (so active node)
|
||||||
|
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
|
||||||
|
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
|
||||||
|
# to change according to duck typing.
|
||||||
if source_node.hardware_state == HardwareState.ON:
|
if source_node.hardware_state == HardwareState.ON:
|
||||||
if source_node.has_service(protocol):
|
if source_node.has_service(protocol):
|
||||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""The Transaction class."""
|
"""The Transaction class."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Tuple, TYPE_CHECKING, Union
|
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from primaite.common.enums import AgentIdentifier
|
from primaite.common.enums import AgentIdentifier
|
||||||
|
|
||||||
@@ -31,15 +31,15 @@ class Transaction(object):
|
|||||||
"The step number"
|
"The step number"
|
||||||
self.obs_space: "spaces.Space" = None
|
self.obs_space: "spaces.Space" = None
|
||||||
"The observation space (pre)"
|
"The observation space (pre)"
|
||||||
self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None
|
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||||
"The observation space before any actions are taken"
|
"The observation space before any actions are taken"
|
||||||
self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None
|
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||||
"The observation space after any actions are taken"
|
"The observation space after any actions are taken"
|
||||||
self.reward: float = None
|
self.reward: Optional[float] = None
|
||||||
"The reward value"
|
"The reward value"
|
||||||
self.action_space: int = None
|
self.action_space: Optional[int] = None
|
||||||
"The action space invoked by the agent"
|
"The action space invoked by the agent"
|
||||||
self.obs_space_description: List[str] = None
|
self.obs_space_description: Optional[List[str]] = None
|
||||||
"The env observation space description"
|
"The env observation space description"
|
||||||
|
|
||||||
def as_csv_data(self) -> Tuple[List, List]:
|
def as_csv_data(self) -> Tuple[List, List]:
|
||||||
|
|||||||
Reference in New Issue
Block a user