Add More Typehint

This commit is contained in:
Marek Wolan
2023-07-13 18:08:44 +01:00
parent 4e4166d4d4
commit a923d818d3
9 changed files with 107 additions and 83 deletions

View File

@@ -148,6 +148,7 @@ class ActionType(Enum):
ANY = 2
# TODO: this is not used anymore, write a ticket to delete it.
class ObservationType(Enum):
"""Observation type enumeration."""

View File

@@ -562,7 +562,7 @@ class Primaite(Env):
else:
return
def apply_actions_to_acl(self, _action):
def apply_actions_to_acl(self, _action: int) -> None:
"""
Applies agent actions to the Access Control List [TO DO].
@@ -640,7 +640,7 @@ class Primaite(Env):
else:
return
def apply_time_based_updates(self):
def apply_time_based_updates(self) -> None:
"""
Updates anything that needs to count down and then change state.
@@ -696,12 +696,12 @@ class Primaite(Env):
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
def update_environent_obs(self) -> None:
"""Updates the observation space based on the node and link status."""
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
def load_lay_down_config(self):
def load_lay_down_config(self) -> None:
"""Loads config data in order to build the environment configuration."""
for item in self.lay_down_config:
if item["item_type"] == "NODE":
@@ -739,7 +739,7 @@ class Primaite(Env):
_LOGGER.info("Environment configuration loaded")
print("Environment configuration loaded")
def create_node(self, item):
def create_node(self, item: Dict) -> None:
"""
Creates a node from config data.
@@ -820,7 +820,7 @@ class Primaite(Env):
# Add node to network (reference)
self.network_reference.add_nodes_from([node_ref])
def create_link(self, item: Dict):
def create_link(self, item: Dict) -> None:
"""
Creates a link from config data.
@@ -864,7 +864,7 @@ class Primaite(Env):
self.services_list,
)
def create_green_ier(self, item):
def create_green_ier(self, item: Dict) -> None:
"""
Creates a green IER from config data.
@@ -905,7 +905,7 @@ class Primaite(Env):
ier_mission_criticality,
)
def create_red_ier(self, item):
def create_red_ier(self, item: Dict) -> None:
"""
Creates a red IER from config data.
@@ -935,7 +935,7 @@ class Primaite(Env):
ier_mission_criticality,
)
def create_green_pol(self, item):
def create_green_pol(self, item: Dict) -> None:
"""
Creates a green PoL object from config data.
@@ -969,7 +969,7 @@ class Primaite(Env):
pol_state,
)
def create_red_pol(self, item):
def create_red_pol(self, item: Dict) -> None:
"""
Creates a red PoL object from config data.
@@ -1010,7 +1010,7 @@ class Primaite(Env):
pol_source_node_service_state,
)
def create_acl_rule(self, item):
def create_acl_rule(self, item: Dict) -> None:
"""
Creates an ACL rule from config data.
@@ -1031,7 +1031,8 @@ class Primaite(Env):
acl_rule_port,
)
def create_services_list(self, services):
# TODO: confirm typehint using runtime
def create_services_list(self, services: Dict) -> None:
"""
Creates a list of services (enum) from config data.
@@ -1047,7 +1048,7 @@ class Primaite(Env):
# Set the number of services
self.num_services = len(self.services_list)
def create_ports_list(self, ports):
def create_ports_list(self, ports: Dict) -> None:
"""
Creates a list of ports from config data.
@@ -1063,7 +1064,8 @@ class Primaite(Env):
# Set the number of ports
self.num_ports = len(self.ports_list)
def get_observation_info(self, observation_info):
# TODO: this is not used anymore, write a ticket to delete it
def get_observation_info(self, observation_info: Dict) -> None:
"""
Extracts observation_info.
@@ -1072,7 +1074,8 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
# TODO: this is not used anymore, write a ticket to delete it.
def get_action_info(self, action_info: Dict) -> None:
"""
Extracts action_info.
@@ -1081,7 +1084,7 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def save_obs_config(self, obs_config: dict):
def save_obs_config(self, obs_config: dict) -> None:
"""
Cache the config for the observation space.
@@ -1094,7 +1097,7 @@ class Primaite(Env):
"""
self.obs_config = obs_config
def reset_environment(self):
def reset_environment(self) -> None:
"""
Resets environment.
@@ -1119,7 +1122,7 @@ class Primaite(Env):
for ier_key, ier_value in self.red_iers.items():
ier_value.set_is_running(False)
def reset_node(self, item):
def reset_node(self, item: Dict) -> None:
"""
Resets the statuses of a node.
@@ -1167,7 +1170,7 @@ class Primaite(Env):
# Bad formatting
pass
def create_node_action_dict(self):
def create_node_action_dict(self) -> Dict[int, List[int]]:
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
@@ -1202,7 +1205,7 @@ class Primaite(Env):
return actions
def create_acl_action_dict(self):
def create_acl_action_dict(self) -> Dict[int, List[int]]:
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
@@ -1232,7 +1235,7 @@ class Primaite(Env):
return actions
def create_node_and_acl_action_dict(self):
def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]:
"""
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
@@ -1249,7 +1252,7 @@ class Primaite(Env):
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict
def _create_random_red_agent(self):
def _create_random_red_agent(self) -> None:
"""Decide on random red agent for the episode to be called in env.reset()."""
# Reset the current red iers and red node pol
self.red_iers = {}

View File

@@ -1,25 +1,32 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict
from typing import Dict, TYPE_CHECKING
from primaite import getLogger
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
from primaite.config.training_config import TrainingConfig
from primaite.pol.ier import IER
_LOGGER: "Logger" = getLogger(__name__)
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
green_iers_reference,
red_iers,
step_count,
config_values,
initial_nodes: Dict[str, NodeUnion],
final_nodes: Dict[str, NodeUnion],
reference_nodes: Dict[str, NodeUnion],
green_iers: Dict[str, "IER"],
green_iers_reference: Dict[str, "IER"],
red_iers: Dict[str, "IER"],
step_count: int,
config_values: "TrainingConfig",
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
@@ -93,7 +100,9 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_operating_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the hardware state of a node.
@@ -142,7 +151,9 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_os_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the Software State of a node.
@@ -193,7 +204,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
def score_node_service_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the service state(s) of a node.
@@ -265,7 +278,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
def score_node_file_system(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the file system state of a node.

View File

@@ -8,7 +8,7 @@ from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
"""
Initialise a Link within the simulated network.
@@ -18,17 +18,17 @@ class Link(object):
:param _dest_node_name: The name of the destination node
:param _protocols: The protocols to add to the link
"""
self.id = _id
self.bandwidth = _bandwidth
self.source_node_name = _source_node_name
self.dest_node_name = _dest_node_name
self.id: str = _id
self.bandwidth: int = _bandwidth
self.source_node_name: str = _source_node_name
self.dest_node_name: str = _dest_node_name
self.protocol_list: List[Protocol] = []
# Add the default protocols
for protocol_name in _services:
self.add_protocol(protocol_name)
def add_protocol(self, _protocol):
def add_protocol(self, _protocol: str) -> None:
"""
Adds a new protocol to the list of protocols on this link.
@@ -37,7 +37,7 @@ class Link(object):
"""
self.protocol_list.append(Protocol(_protocol))
def get_id(self):
def get_id(self) -> str:
"""
Gets link ID.
@@ -46,7 +46,7 @@ class Link(object):
"""
return self.id
def get_source_node_name(self):
def get_source_node_name(self) -> str:
"""
Gets source node name.
@@ -55,7 +55,7 @@ class Link(object):
"""
return self.source_node_name
def get_dest_node_name(self):
def get_dest_node_name(self) -> str:
"""
Gets destination node name.
@@ -64,7 +64,7 @@ class Link(object):
"""
return self.dest_node_name
def get_bandwidth(self):
def get_bandwidth(self) -> int:
"""
Gets bandwidth of link.
@@ -73,7 +73,7 @@ class Link(object):
"""
return self.bandwidth
def get_protocol_list(self):
def get_protocol_list(self) -> List[Protocol]:
"""
Gets list of protocols on this link.
@@ -82,7 +82,7 @@ class Link(object):
"""
return self.protocol_list
def get_current_load(self):
def get_current_load(self) -> int:
"""
Gets current total load on this link.
@@ -94,7 +94,7 @@ class Link(object):
total_load += protocol.get_load()
return total_load
def add_protocol_load(self, _protocol, _load):
def add_protocol_load(self, _protocol: str, _load: int) -> None:
"""
Adds a loading to a protocol on this link.
@@ -108,7 +108,7 @@ class Link(object):
else:
pass
def clear_traffic(self):
def clear_traffic(self) -> None:
"""Clears all traffic on this link."""
for protocol in self.protocol_list:
protocol.clear_load()

View File

@@ -24,7 +24,7 @@ class ActiveNode(Node):
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise an active node.
@@ -60,7 +60,7 @@ class ActiveNode(Node):
return self._software_state
@software_state.setter
def software_state(self, software_state: SoftwareState):
def software_state(self, software_state: SoftwareState) -> None:
"""
Get the software_state.
@@ -79,7 +79,7 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
"""
Sets Software State if the node is not compromised.
@@ -99,14 +99,14 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def update_os_patching_status(self):
def update_os_patching_status(self) -> None:
"""Updates operating system status based on patching cycle."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self._software_state = SoftwareState.GOOD
def set_file_system_state(self, file_system_state: FileSystemState):
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed).
@@ -133,7 +133,7 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed) if not in a compromised state.
@@ -166,12 +166,12 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def start_file_system_scan(self):
def start_file_system_scan(self) -> None:
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self):
def update_file_system_state(self) -> None:
"""Updates file system status based on scanning/restore/repair cycle."""
# Deprecate both the action count (for restoring or reparing) and the scanning count
self.file_system_action_count -= 1
@@ -193,14 +193,14 @@ class ActiveNode(Node):
self.file_system_scanning = False
self.file_system_scanning_count = 0
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Updates the reset count & makes software and file state to GOOD."""
super().update_resetting_status()
if self.resetting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Updates the booting software and file state to GOOD."""
super().update_booting_status()
if self.booting_count <= 0:

View File

@@ -38,40 +38,40 @@ class Node:
self.booting_count: int = 0
self.shutting_down_count: int = 0
def __repr__(self):
def __repr__(self) -> str:
"""Returns the name of the node."""
return self.name
def turn_on(self):
def turn_on(self) -> None:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
def turn_off(self) -> None:
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
self.shutting_down_count = self.config_values.node_shutdown_duration
def reset(self):
def reset(self) -> None:
"""Sets the node state to Resetting and starts the reset count."""
self.hardware_state = HardwareState.RESETTING
self.resetting_count = self.config_values.node_reset_duration
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Updates the resetting count."""
self.resetting_count -= 1
if self.resetting_count <= 0:
self.resetting_count = 0
self.hardware_state = HardwareState.ON
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
def update_shutdown_status(self) -> None:
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:

View File

@@ -1,5 +1,9 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from primaite.common.enums import HardwareState, NodePOLType, SoftwareState
class NodeStateInstructionGreen(object):
@@ -7,10 +11,10 @@ class NodeStateInstructionGreen(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_node_id,
_id: str,
_start_step: int,
_end_step: int,
_node_id: str,
_node_pol_type,
_service_name,
_state,
@@ -30,9 +34,10 @@ class NodeStateInstructionGreen(object):
self.start_step = _start_step
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type = _node_pol_type
self.service_name = _service_name # Not used when not a service instruction
self.state = _state
self.node_pol_type: "NodePOLType" = _node_pol_type
self.service_name: str = _service_name # Not used when not a service instruction
# TODO: confirm type of state
self.state: Union["HardwareState", "SoftwareState"] = _state
def get_start_step(self):
"""

View File

@@ -16,7 +16,7 @@ class PassiveNode(Node):
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a passive node.

View File

@@ -25,7 +25,7 @@ class ServiceNode(ActiveNode):
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
) -> None:
"""
Initialise a Service Node.
@@ -52,7 +52,7 @@ class ServiceNode(ActiveNode):
)
self.services: Dict[str, Service] = {}
def add_service(self, service: Service):
def add_service(self, service: Service) -> None:
"""
Adds a service to the node.
@@ -102,7 +102,7 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
@@ -131,7 +131,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
@@ -158,7 +158,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def get_service_state(self, protocol_name):
def get_service_state(self, protocol_name: str) -> SoftwareState:
"""
Gets the state of a service.
@@ -169,20 +169,20 @@ class ServiceNode(ActiveNode):
if service_value:
return service_value.software_state
def update_services_patching_status(self):
def update_services_patching_status(self) -> None:
"""Updates the patching counter for any service that are patching."""
for service_key, service_value in self.services.items():
if service_value.software_state == SoftwareState.PATCHING:
service_value.reduce_patching_count()
def update_resetting_status(self):
def update_resetting_status(self) -> None:
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
def update_booting_status(self) -> None:
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0: