diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index db5d153c..ff090ca9 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -148,6 +148,7 @@ class ActionType(Enum): ANY = 2 +# TODO: this is not used anymore, write a ticket to delete it. class ObservationType(Enum): """Observation type enumeration.""" diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5bf843f1..d1c8adf5 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -562,7 +562,7 @@ class Primaite(Env): else: return - def apply_actions_to_acl(self, _action): + def apply_actions_to_acl(self, _action: int) -> None: """ Applies agent actions to the Access Control List [TO DO]. @@ -640,7 +640,7 @@ class Primaite(Env): else: return - def apply_time_based_updates(self): + def apply_time_based_updates(self) -> None: """ Updates anything that needs to count down and then change state. @@ -696,12 +696,12 @@ class Primaite(Env): return self.obs_handler.space, self.obs_handler.current_observation - def update_environent_obs(self): + def update_environent_obs(self) -> None: """Updates the observation space based on the node and link status.""" self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation - def load_lay_down_config(self): + def load_lay_down_config(self) -> None: """Loads config data in order to build the environment configuration.""" for item in self.lay_down_config: if item["item_type"] == "NODE": @@ -739,7 +739,7 @@ class Primaite(Env): _LOGGER.info("Environment configuration loaded") print("Environment configuration loaded") - def create_node(self, item): + def create_node(self, item: Dict) -> None: """ Creates a node from config data. @@ -820,7 +820,7 @@ class Primaite(Env): # Add node to network (reference) self.network_reference.add_nodes_from([node_ref]) - def create_link(self, item: Dict): + def create_link(self, item: Dict) -> None: """ Creates a link from config data. @@ -864,7 +864,7 @@ class Primaite(Env): self.services_list, ) - def create_green_ier(self, item): + def create_green_ier(self, item: Dict) -> None: """ Creates a green IER from config data. @@ -905,7 +905,7 @@ class Primaite(Env): ier_mission_criticality, ) - def create_red_ier(self, item): + def create_red_ier(self, item: Dict) -> None: """ Creates a red IER from config data. @@ -935,7 +935,7 @@ class Primaite(Env): ier_mission_criticality, ) - def create_green_pol(self, item): + def create_green_pol(self, item: Dict) -> None: """ Creates a green PoL object from config data. @@ -969,7 +969,7 @@ class Primaite(Env): pol_state, ) - def create_red_pol(self, item): + def create_red_pol(self, item: Dict) -> None: """ Creates a red PoL object from config data. @@ -1010,7 +1010,7 @@ class Primaite(Env): pol_source_node_service_state, ) - def create_acl_rule(self, item): + def create_acl_rule(self, item: Dict) -> None: """ Creates an ACL rule from config data. @@ -1031,7 +1031,8 @@ class Primaite(Env): acl_rule_port, ) - def create_services_list(self, services): + # TODO: confirm typehint using runtime + def create_services_list(self, services: Dict) -> None: """ Creates a list of services (enum) from config data. @@ -1047,7 +1048,7 @@ class Primaite(Env): # Set the number of services self.num_services = len(self.services_list) - def create_ports_list(self, ports): + def create_ports_list(self, ports: Dict) -> None: """ Creates a list of ports from config data. @@ -1063,7 +1064,8 @@ class Primaite(Env): # Set the number of ports self.num_ports = len(self.ports_list) - def get_observation_info(self, observation_info): + # TODO: this is not used anymore, write a ticket to delete it + def get_observation_info(self, observation_info: Dict) -> None: """ Extracts observation_info. @@ -1072,7 +1074,8 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): + # TODO: this is not used anymore, write a ticket to delete it. + def get_action_info(self, action_info: Dict) -> None: """ Extracts action_info. @@ -1081,7 +1084,7 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def save_obs_config(self, obs_config: dict): + def save_obs_config(self, obs_config: dict) -> None: """ Cache the config for the observation space. @@ -1094,7 +1097,7 @@ class Primaite(Env): """ self.obs_config = obs_config - def reset_environment(self): + def reset_environment(self) -> None: """ Resets environment. @@ -1119,7 +1122,7 @@ class Primaite(Env): for ier_key, ier_value in self.red_iers.items(): ier_value.set_is_running(False) - def reset_node(self, item): + def reset_node(self, item: Dict) -> None: """ Resets the statuses of a node. @@ -1167,7 +1170,7 @@ class Primaite(Env): # Bad formatting pass - def create_node_action_dict(self): + def create_node_action_dict(self) -> Dict[int, List[int]]: """ Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. @@ -1202,7 +1205,7 @@ class Primaite(Env): return actions - def create_acl_action_dict(self): + def create_acl_action_dict(self) -> Dict[int, List[int]]: """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.""" # reserve 0 action to be a nothing action actions = {0: [0, 0, 0, 0, 0, 0]} @@ -1232,7 +1235,7 @@ class Primaite(Env): return actions - def create_node_and_acl_action_dict(self): + def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]: """ Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. @@ -1249,7 +1252,7 @@ class Primaite(Env): combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict - def _create_random_red_agent(self): + def _create_random_red_agent(self) -> None: """Decide on random red agent for the episode to be called in env.reset().""" # Reset the current red iers and red node pol self.red_iers = {} diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 9cbb0078..c9acd921 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -1,25 +1,32 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Implements reward function.""" -from typing import Dict +from typing import Dict, TYPE_CHECKING from primaite import getLogger +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.service import Service from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + + from primaite.config.training_config import TrainingConfig + from primaite.pol.ier import IER + +_LOGGER: "Logger" = getLogger(__name__) def calculate_reward_function( - initial_nodes, - final_nodes, - reference_nodes, - green_iers, - green_iers_reference, - red_iers, - step_count, - config_values, + initial_nodes: Dict[str, NodeUnion], + final_nodes: Dict[str, NodeUnion], + reference_nodes: Dict[str, NodeUnion], + green_iers: Dict[str, "IER"], + green_iers_reference: Dict[str, "IER"], + red_iers: Dict[str, "IER"], + step_count: int, + config_values: "TrainingConfig", ) -> float: """ Compares the states of the initial and final nodes/links to get a reward. @@ -93,7 +100,9 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float: +def score_node_operating_state( + final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" +) -> float: """ Calculates score relating to the hardware state of a node. @@ -142,7 +151,9 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float: +def score_node_os_state( + final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" +) -> float: """ Calculates score relating to the Software State of a node. @@ -193,7 +204,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float: +def score_node_service_state( + final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" +) -> float: """ Calculates score relating to the service state(s) of a node. @@ -265,7 +278,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float: +def score_node_file_system( + final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" +) -> float: """ Calculates score relating to the file system state of a node. diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index f61281cd..145de5f3 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -8,7 +8,7 @@ from primaite.common.protocol import Protocol class Link(object): """Link class.""" - def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): + def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None: """ Initialise a Link within the simulated network. @@ -18,17 +18,17 @@ class Link(object): :param _dest_node_name: The name of the destination node :param _protocols: The protocols to add to the link """ - self.id = _id - self.bandwidth = _bandwidth - self.source_node_name = _source_node_name - self.dest_node_name = _dest_node_name + self.id: str = _id + self.bandwidth: int = _bandwidth + self.source_node_name: str = _source_node_name + self.dest_node_name: str = _dest_node_name self.protocol_list: List[Protocol] = [] # Add the default protocols for protocol_name in _services: self.add_protocol(protocol_name) - def add_protocol(self, _protocol): + def add_protocol(self, _protocol: str) -> None: """ Adds a new protocol to the list of protocols on this link. @@ -37,7 +37,7 @@ class Link(object): """ self.protocol_list.append(Protocol(_protocol)) - def get_id(self): + def get_id(self) -> str: """ Gets link ID. @@ -46,7 +46,7 @@ class Link(object): """ return self.id - def get_source_node_name(self): + def get_source_node_name(self) -> str: """ Gets source node name. @@ -55,7 +55,7 @@ class Link(object): """ return self.source_node_name - def get_dest_node_name(self): + def get_dest_node_name(self) -> str: """ Gets destination node name. @@ -64,7 +64,7 @@ class Link(object): """ return self.dest_node_name - def get_bandwidth(self): + def get_bandwidth(self) -> int: """ Gets bandwidth of link. @@ -73,7 +73,7 @@ class Link(object): """ return self.bandwidth - def get_protocol_list(self): + def get_protocol_list(self) -> List[Protocol]: """ Gets list of protocols on this link. @@ -82,7 +82,7 @@ class Link(object): """ return self.protocol_list - def get_current_load(self): + def get_current_load(self) -> int: """ Gets current total load on this link. @@ -94,7 +94,7 @@ class Link(object): total_load += protocol.get_load() return total_load - def add_protocol_load(self, _protocol, _load): + def add_protocol_load(self, _protocol: str, _load: int) -> None: """ Adds a loading to a protocol on this link. @@ -108,7 +108,7 @@ class Link(object): else: pass - def clear_traffic(self): + def clear_traffic(self) -> None: """Clears all traffic on this link.""" for protocol in self.protocol_list: protocol.clear_load() diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index f86f818b..b73f80f0 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -24,7 +24,7 @@ class ActiveNode(Node): software_state: SoftwareState, file_system_state: FileSystemState, config_values: TrainingConfig, - ): + ) -> None: """ Initialise an active node. @@ -60,7 +60,7 @@ class ActiveNode(Node): return self._software_state @software_state.setter - def software_state(self, software_state: SoftwareState): + def software_state(self, software_state: SoftwareState) -> None: """ Get the software_state. @@ -79,7 +79,7 @@ class ActiveNode(Node): f"Node.software_state:{self._software_state}" ) - def set_software_state_if_not_compromised(self, software_state: SoftwareState): + def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None: """ Sets Software State if the node is not compromised. @@ -99,14 +99,14 @@ class ActiveNode(Node): f"Node.software_state:{self._software_state}" ) - def update_os_patching_status(self): + def update_os_patching_status(self) -> None: """Updates operating system status based on patching cycle.""" self.patching_count -= 1 if self.patching_count <= 0: self.patching_count = 0 self._software_state = SoftwareState.GOOD - def set_file_system_state(self, file_system_state: FileSystemState): + def set_file_system_state(self, file_system_state: FileSystemState) -> None: """ Sets the file system state (actual and observed). @@ -133,7 +133,7 @@ class ActiveNode(Node): f"Node.file_system_state.actual:{self.file_system_state_actual}" ) - def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState): + def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None: """ Sets the file system state (actual and observed) if not in a compromised state. @@ -166,12 +166,12 @@ class ActiveNode(Node): f"Node.file_system_state.actual:{self.file_system_state_actual}" ) - def start_file_system_scan(self): + def start_file_system_scan(self) -> None: """Starts a file system scan.""" self.file_system_scanning = True self.file_system_scanning_count = self.config_values.file_system_scanning_limit - def update_file_system_state(self): + def update_file_system_state(self) -> None: """Updates file system status based on scanning/restore/repair cycle.""" # Deprecate both the action count (for restoring or reparing) and the scanning count self.file_system_action_count -= 1 @@ -193,14 +193,14 @@ class ActiveNode(Node): self.file_system_scanning = False self.file_system_scanning_count = 0 - def update_resetting_status(self): + def update_resetting_status(self) -> None: """Updates the reset count & makes software and file state to GOOD.""" super().update_resetting_status() if self.resetting_count <= 0: self.file_system_state_actual = FileSystemState.GOOD self.software_state = SoftwareState.GOOD - def update_booting_status(self): + def update_booting_status(self) -> None: """Updates the booting software and file state to GOOD.""" super().update_booting_status() if self.booting_count <= 0: diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 9fd5b719..cd500c9e 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -38,40 +38,40 @@ class Node: self.booting_count: int = 0 self.shutting_down_count: int = 0 - def __repr__(self): + def __repr__(self) -> str: """Returns the name of the node.""" return self.name - def turn_on(self): + def turn_on(self) -> None: """Sets the node state to ON.""" self.hardware_state = HardwareState.BOOTING self.booting_count = self.config_values.node_booting_duration - def turn_off(self): + def turn_off(self) -> None: """Sets the node state to OFF.""" self.hardware_state = HardwareState.OFF self.shutting_down_count = self.config_values.node_shutdown_duration - def reset(self): + def reset(self) -> None: """Sets the node state to Resetting and starts the reset count.""" self.hardware_state = HardwareState.RESETTING self.resetting_count = self.config_values.node_reset_duration - def update_resetting_status(self): + def update_resetting_status(self) -> None: """Updates the resetting count.""" self.resetting_count -= 1 if self.resetting_count <= 0: self.resetting_count = 0 self.hardware_state = HardwareState.ON - def update_booting_status(self): + def update_booting_status(self) -> None: """Updates the booting count.""" self.booting_count -= 1 if self.booting_count <= 0: self.booting_count = 0 self.hardware_state = HardwareState.ON - def update_shutdown_status(self): + def update_shutdown_status(self) -> None: """Updates the shutdown count.""" self.shutting_down_count -= 1 if self.shutting_down_count <= 0: diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 7ebe3886..5a225c25 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -1,5 +1,9 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from primaite.common.enums import HardwareState, NodePOLType, SoftwareState class NodeStateInstructionGreen(object): @@ -7,10 +11,10 @@ class NodeStateInstructionGreen(object): def __init__( self, - _id, - _start_step, - _end_step, - _node_id, + _id: str, + _start_step: int, + _end_step: int, + _node_id: str, _node_pol_type, _service_name, _state, @@ -30,9 +34,10 @@ class NodeStateInstructionGreen(object): self.start_step = _start_step self.end_step = _end_step self.node_id = _node_id - self.node_pol_type = _node_pol_type - self.service_name = _service_name # Not used when not a service instruction - self.state = _state + self.node_pol_type: "NodePOLType" = _node_pol_type + self.service_name: str = _service_name # Not used when not a service instruction + # TODO: confirm type of state + self.state: Union["HardwareState", "SoftwareState"] = _state def get_start_step(self): """ diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index afe4e2d1..c79636e3 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -16,7 +16,7 @@ class PassiveNode(Node): priority: Priority, hardware_state: HardwareState, config_values: TrainingConfig, - ): + ) -> None: """ Initialise a passive node. diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 4ad52a1e..ef0cd92e 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -25,7 +25,7 @@ class ServiceNode(ActiveNode): software_state: SoftwareState, file_system_state: FileSystemState, config_values: TrainingConfig, - ): + ) -> None: """ Initialise a Service Node. @@ -52,7 +52,7 @@ class ServiceNode(ActiveNode): ) self.services: Dict[str, Service] = {} - def add_service(self, service: Service): + def add_service(self, service: Service) -> None: """ Adds a service to the node. @@ -102,7 +102,7 @@ class ServiceNode(ActiveNode): return False return False - def set_service_state(self, protocol_name: str, software_state: SoftwareState): + def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None: """ Sets the software_state of a service (protocol) on the node. @@ -131,7 +131,7 @@ class ServiceNode(ActiveNode): f"Node.services[].software_state:{software_state}" ) - def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState): + def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None: """ Sets the software_state of a service (protocol) on the node. @@ -158,7 +158,7 @@ class ServiceNode(ActiveNode): f"Node.services[].software_state:{software_state}" ) - def get_service_state(self, protocol_name): + def get_service_state(self, protocol_name: str) -> SoftwareState: """ Gets the state of a service. @@ -169,20 +169,20 @@ class ServiceNode(ActiveNode): if service_value: return service_value.software_state - def update_services_patching_status(self): + def update_services_patching_status(self) -> None: """Updates the patching counter for any service that are patching.""" for service_key, service_value in self.services.items(): if service_value.software_state == SoftwareState.PATCHING: service_value.reduce_patching_count() - def update_resetting_status(self): + def update_resetting_status(self) -> None: """Update resetting counter and set software state if it reached 0.""" super().update_resetting_status() if self.resetting_count <= 0: for service in self.services.values(): service.software_state = SoftwareState.GOOD - def update_booting_status(self): + def update_booting_status(self) -> None: """Update booting counter and set software to good if it reached 0.""" super().update_booting_status() if self.booting_count <= 0: