diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 1f06a371..90860f7d 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -5,7 +5,7 @@ import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Dict, Final, Union +from typing import Any, Dict, Final, TYPE_CHECKING, Union from uuid import uuid4 import yaml @@ -17,7 +17,13 @@ from primaite.config.training_config import TrainingConfig from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + + import numpy as np + + +_LOGGER: "Logger" = getLogger(__name__) def get_session_path(session_timestamp: datetime) -> Path: @@ -47,7 +53,7 @@ class AgentSessionABC(ABC): """ @abstractmethod - def __init__(self, training_config_path, lay_down_config_path): + def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: """ Initialise an agent session from config files. @@ -107,11 +113,11 @@ class AgentSessionABC(ABC): return path @property - def uuid(self): + def uuid(self) -> str: """The Agent Session UUID.""" return self._uuid - def _write_session_metadata_file(self): + def _write_session_metadata_file(self) -> None: """ Write the ``session_metadata.json`` file. @@ -147,7 +153,7 @@ class AgentSessionABC(ABC): json.dump(metadata_dict, file) _LOGGER.debug("Finished writing session metadata file") - def _update_session_metadata_file(self): + def _update_session_metadata_file(self) -> None: """ Update the ``session_metadata.json`` file. 
@@ -176,7 +182,7 @@ class AgentSessionABC(ABC): _LOGGER.debug("Finished updating session metadata file") @abstractmethod - def _setup(self): + def _setup(self) -> None: _LOGGER.info( "Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})" ) @@ -186,14 +192,14 @@ class AgentSessionABC(ABC): self._can_evaluate = False @abstractmethod - def _save_checkpoint(self): + def _save_checkpoint(self) -> None: pass @abstractmethod def learn( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Train the agent. @@ -210,8 +216,8 @@ class AgentSessionABC(ABC): @abstractmethod def evaluate( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Evaluate the agent. @@ -224,7 +230,7 @@ class AgentSessionABC(ABC): _LOGGER.info("Finished evaluation") @abstractmethod - def _get_latest_checkpoint(self): + def _get_latest_checkpoint(self) -> None: pass @classmethod @@ -264,7 +270,6 @@ class AgentSessionABC(ABC): msg = f"Failed to load PrimAITE Session, path does not exist: {path}" _LOGGER.error(msg) raise FileNotFoundError(msg) - pass @property def _saved_agent_path(self) -> Path: @@ -276,21 +281,21 @@ class AgentSessionABC(ABC): return self.learning_path / file_name @abstractmethod - def save(self): + def save(self) -> None: """Save the agent.""" pass @abstractmethod - def export(self): + def export(self) -> None: """Export the agent to transportable file format.""" pass - def close(self): + def close(self) -> None: """Closes the agent.""" self._env.episode_av_reward_writer.close() # noqa self._env.transaction_writer.close() # noqa - def _plot_av_reward_per_episode(self, learning_session: bool = True): + def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None: # self.close() title = f"PrimAITE Session {self.timestamp_str} " subtitle = str(self._training_config) @@ -318,7 +323,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): implemented. 
""" - def __init__(self, training_config_path, lay_down_config_path): + def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: """ Initialise a hardcoded agent session. @@ -331,7 +336,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): super().__init__(training_config_path, lay_down_config_path) self._setup() - def _setup(self): + def _setup(self) -> None: self._env: Primaite = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, @@ -342,16 +347,16 @@ class HardCodedAgentSessionABC(AgentSessionABC): self._can_learn = False self._can_evaluate = True - def _save_checkpoint(self): + def _save_checkpoint(self) -> None: pass - def _get_latest_checkpoint(self): + def _get_latest_checkpoint(self) -> None: pass def learn( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Train the agent. @@ -360,13 +365,13 @@ class HardCodedAgentSessionABC(AgentSessionABC): _LOGGER.warning("Deterministic agents cannot learn") @abstractmethod - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> None: pass def evaluate( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Evaluate the agent. 
@@ -398,14 +403,14 @@ class HardCodedAgentSessionABC(AgentSessionABC): super().evaluate() @classmethod - def load(cls): + def load(cls) -> None: """Load an agent from file.""" _LOGGER.warning("Deterministic agents cannot be loaded") - def save(self): + def save(self) -> None: """Save the agent.""" _LOGGER.warning("Deterministic agents cannot be saved") - def export(self): + def export(self) -> None: """Export the agent to transportable file format.""" _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 166ff415..98c1d7d9 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Dict, List, Union import numpy as np @@ -32,7 +32,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_blocked_green_iers( self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[Any, Any]: + ) -> Dict[str, IER]: """Get blocked green IERs. :param green_iers: Green IERs to check for being @@ -60,7 +60,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers - def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]): + def get_matching_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[int, ACLRule]: """Get list of ACL rules which are relevant to an IER. :param ier: Information Exchange Request to query against the ACL list @@ -83,7 +85,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_blocking_acl_rules_for_ier( self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[str, Any]: + ) -> Dict[int, ACLRule]: """ Get blocking ACL rules for an IER. 
@@ -111,7 +113,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules_for_ier( self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[str, Any]: + ) -> Dict[int, ACLRule]: """Get all allowing ACL rules for an IER. :param ier: Information Exchange Request to query against the ACL list @@ -141,7 +143,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): acl: AccessControlList, nodes: Dict[str, Union[ServiceNode, ActiveNode]], services_list: List[str], - ) -> Dict[str, ACLRule]: + ) -> Dict[int, ACLRule]: """Filter ACL rules to only those which are relevant to the specified nodes. :param source_node_id: Source node @@ -186,7 +188,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str], - ) -> Dict[str, ACLRule]: + ) -> Dict[int, ACLRule]: """List ALLOW rules relating to specified nodes. :param source_node_id: Source node id @@ -233,7 +235,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str], - ) -> Dict[str, ACLRule]: + ) -> Dict[int, ACLRule]: """List DENY rules relating to specified nodes. 
:param source_node_id: Source node id diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 6253f574..6674a8df 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -4,7 +4,7 @@ import json import shutil from datetime import datetime from pathlib import Path -from typing import Union +from typing import Any, Callable, Dict, TYPE_CHECKING, Union from uuid import uuid4 from ray.rllib.algorithms import Algorithm @@ -18,10 +18,14 @@ from primaite.agents.agent import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + +_LOGGER: "Logger" = getLogger(__name__) -def _env_creator(env_config): +# TODO: verify type of env_config +def _env_creator(env_config: Dict[str, Any]) -> Primaite: return Primaite( training_config_path=env_config["training_config_path"], lay_down_config_path=env_config["lay_down_config_path"], @@ -30,11 +34,12 @@ def _env_creator(env_config): ) -def _custom_log_creator(session_path: Path): +# TODO: verify type hint return type +def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: logdir = session_path / "ray_results" logdir.mkdir(parents=True, exist_ok=True) - def logger_creator(config): + def logger_creator(config: Dict) -> UnifiedLogger: return UnifiedLogger(config, logdir, loggers=None) return logger_creator @@ -43,7 +48,7 @@ def _custom_log_creator(session_path: Path): class RLlibAgent(AgentSessionABC): """An AgentSession class that implements a Ray RLlib agent.""" - def __init__(self, training_config_path, lay_down_config_path): + def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: """ Initialise the RLLib Agent training session. 
@@ -82,7 +87,7 @@ class RLlibAgent(AgentSessionABC): f"{self._training_config.deep_learning_framework}" ) - def _update_session_metadata_file(self): + def _update_session_metadata_file(self) -> None: """ Update the ``session_metadata.json`` file. @@ -110,7 +115,7 @@ class RLlibAgent(AgentSessionABC): json.dump(metadata_dict, file) _LOGGER.debug("Finished updating session metadata file") - def _setup(self): + def _setup(self) -> None: super()._setup() register_env("primaite", _env_creator) self._agent_config = self._agent_config_class() @@ -147,8 +152,8 @@ class RLlibAgent(AgentSessionABC): def learn( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Evaluate the agent. @@ -168,8 +173,8 @@ class RLlibAgent(AgentSessionABC): def evaluate( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Evaluate the agent. @@ -177,7 +182,7 @@ class RLlibAgent(AgentSessionABC): """ raise NotImplementedError - def _get_latest_checkpoint(self): + def _get_latest_checkpoint(self) -> None: raise NotImplementedError @classmethod @@ -185,7 +190,7 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self, overwrite_existing: bool = True): + def save(self, overwrite_existing: bool = True) -> None: """Save the agent.""" # Make temp dir to save in isolation temp_dir = self.learning_path / str(uuid4()) @@ -205,6 +210,6 @@ class RLlibAgent(AgentSessionABC): # Drop the temp directory shutil.rmtree(temp_dir) - def export(self): + def export(self) -> None: """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index cb00985a..5f04acc0 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Union +from typing import Any, TYPE_CHECKING, Union import numpy as np from stable_baselines3 import A2C, PPO @@ -12,13 +12,16 @@ from 
primaite.agents.agent import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + +_LOGGER: "Logger" = getLogger(__name__) class SB3Agent(AgentSessionABC): """An AgentSession class that implements a Stable Baselines3 agent.""" - def __init__(self, training_config_path, lay_down_config_path): + def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None: """ Initialise the SB3 Agent training session. @@ -57,7 +60,7 @@ class SB3Agent(AgentSessionABC): self.is_eval = False - def _setup(self): + def _setup(self) -> None: super()._setup() self._env = Primaite( training_config_path=self._training_config_path, @@ -75,7 +78,7 @@ class SB3Agent(AgentSessionABC): seed=self._training_config.seed, ) - def _save_checkpoint(self): + def _save_checkpoint(self) -> None: checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count save_checkpoint = False @@ -86,13 +89,13 @@ class SB3Agent(AgentSessionABC): self._agent.save(checkpoint_path) _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") - def _get_latest_checkpoint(self): + def _get_latest_checkpoint(self) -> None: pass def learn( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Train the agent. @@ -115,8 +118,8 @@ class SB3Agent(AgentSessionABC): def evaluate( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Evaluate the agent. 
@@ -150,10 +153,10 @@ class SB3Agent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): + def save(self) -> None: """Save the agent.""" self._agent.save(self._saved_agent_path) - def export(self): + def export(self) -> None: """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index b429a2f5..2c130c0c 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING + from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum +if TYPE_CHECKING: + import numpy as np + class RandomAgent(HardCodedAgentSessionABC): """ @@ -9,7 +14,7 @@ class RandomAgent(HardCodedAgentSessionABC): Get a completely random action from the action space. """ - def _calculate_action(self, obs): + def _calculate_action(self, obs: "np.ndarray") -> int: return self._env.action_space.sample() @@ -20,7 +25,7 @@ class DummyAgent(HardCodedAgentSessionABC): All action spaces setup so dummy action is always 0 regardless of action type used. """ - def _calculate_action(self, obs): + def _calculate_action(self, obs: "np.ndarray") -> int: return 0 @@ -31,7 +36,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC): A valid ACL action that has no effect; does nothing. """ - def _calculate_action(self, obs): + def _calculate_action(self, obs: "np.ndarray") -> int: nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] nothing_action = transform_action_acl_enum(nothing_action) nothing_action = get_new_action(nothing_action, self._env.action_dict) @@ -46,7 +51,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC): A valid Node action that has no effect; does nothing. 
""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: "np.ndarray") -> int: nothing_action = [1, "NONE", "ON", 0] nothing_action = transform_action_node_enum(nothing_action) nothing_action = get_new_action(nothing_action, self._env.action_dict) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 8858fa6a..2e6b3f0c 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -38,7 +38,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: return new_action -def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]: +def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]: """ Transform an ACL action to a more readable format. diff --git a/src/primaite/common/custom_typing.py b/src/primaite/common/custom_typing.py index 37b10efe..e01c8713 100644 --- a/src/primaite/common/custom_typing.py +++ b/src/primaite/common/custom_typing.py @@ -1,8 +1,8 @@ -from typing import Type, Union +from typing import TypeVar from primaite.nodes.active_node import ActiveNode from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode -NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode] +NodeUnion = TypeVar("NodeUnion", ServiceNode, ActiveNode, PassiveNode) """A Union of ActiveNode, PassiveNode, and ServiceNode.""" diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py index ad6a1d83..f7a757e8 100644 --- a/src/primaite/common/protocol.py +++ b/src/primaite/common/protocol.py @@ -5,17 +5,17 @@ class Protocol(object): """Protocol class.""" - def __init__(self, _name): + def __init__(self, _name: str) -> None: """ Initialise a protocol. :param _name: The name of the protocol :type _name: str """ - self.name = _name - self.load = 0 # bps + self.name: str = _name + self.load: int = 0 # bps - def get_name(self): + def get_name(self) -> str: """ Gets the protocol name. 
@@ -24,7 +24,7 @@ class Protocol(object): """ return self.name - def get_load(self): + def get_load(self) -> int: """ Gets the protocol load. @@ -33,7 +33,7 @@ class Protocol(object): """ return self.load - def add_load(self, _load): + def add_load(self, _load: int) -> None: """ Adds load to the protocol. @@ -42,6 +42,6 @@ class Protocol(object): """ self.load += _load - def clear_load(self): + def clear_load(self) -> None: """Clears the load on this protocol.""" self.load = 0 diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py index 258ac8f9..f3dddcc7 100644 --- a/src/primaite/common/service.py +++ b/src/primaite/common/service.py @@ -15,12 +15,12 @@ class Service(object): :param port: The service port. :param software_state: The service SoftwareState. """ - self.name = name - self.port = port - self.software_state = software_state - self.patching_count = 0 + self.name: str = name + self.port: str = port + self.software_state: SoftwareState = software_state + self.patching_count: int = 0 - def reduce_patching_count(self): + def reduce_patching_count(self) -> None: """Reduces the patching count for the service.""" self.patching_count -= 1 if self.patching_count <= 0: diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 3a85b9da..2cc5f9c2 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,12 +1,15 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
from pathlib import Path -from typing import Any, Dict, Final, Union +from typing import Any, Dict, Final, TYPE_CHECKING, Union import yaml from primaite import getLogger, USERS_CONFIG_DIR -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + +_LOGGER: "Logger" = getLogger(__name__) _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 785d9757..5cf62174 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Optional, Union +from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Union import yaml @@ -18,7 +18,10 @@ from primaite.common.enums import ( SessionType, ) -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + +_LOGGER: Logger = getLogger(__name__) _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 53c173fd..cb9872d1 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -14,17 +14,19 @@ from primaite.nodes.service_node import ServiceNode # TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking # Therefore, this avoids circular dependency problem. 
if TYPE_CHECKING: + from logging import Logger + from primaite.environment.primaite_env import Primaite -_LOGGER = logging.getLogger(__name__) +_LOGGER: "Logger" = logging.getLogger(__name__) class AbstractObservationComponent(ABC): """Represents a part of the PrimAITE observation space.""" @abstractmethod - def __init__(self, env: "Primaite"): + def __init__(self, env: "Primaite") -> None: """ Initialise observation component. @@ -39,7 +41,7 @@ class AbstractObservationComponent(ABC): return NotImplemented @abstractmethod - def update(self): + def update(self) -> None: """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented @@ -74,7 +76,7 @@ class NodeLinkTable(AbstractObservationComponent): _MAX_VAL: int = 1_000_000_000 _DATA_TYPE: type = np.int64 - def __init__(self, env: "Primaite"): + def __init__(self, env: "Primaite") -> None: """ Initialise a NodeLinkTable observation space component. @@ -101,7 +103,7 @@ class NodeLinkTable(AbstractObservationComponent): self.structure = self.generate_structure() - def update(self): + def update(self) -> None: """ Update the observation based on current environment state. @@ -148,7 +150,7 @@ class NodeLinkTable(AbstractObservationComponent): protocol_index += 1 item_index += 1 - def generate_structure(self): + def generate_structure(self) -> List[str]: """Return a list of labels for the components of the flattened observation space.""" nodes = self.env.nodes.values() links = self.env.links.values() @@ -211,7 +213,7 @@ class NodeStatuses(AbstractObservationComponent): _DATA_TYPE: type = np.int64 - def __init__(self, env: "Primaite"): + def __init__(self, env: "Primaite") -> None: """ Initialise a NodeStatuses observation component. 
@@ -237,7 +239,7 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) self.structure = self.generate_structure() - def update(self): + def update(self) -> None: """ Update the observation based on current environment state. @@ -268,7 +270,7 @@ class NodeStatuses(AbstractObservationComponent): ) self.current_observation[:] = obs - def generate_structure(self): + def generate_structure(self) -> List[str]: """Return a list of labels for the components of the flattened observation space.""" services = self.env.services_list @@ -317,7 +319,7 @@ class LinkTrafficLevels(AbstractObservationComponent): env: "Primaite", combine_service_traffic: bool = False, quantisation_levels: int = 5, - ): + ) -> None: """ Initialise a LinkTrafficLevels observation component. @@ -359,7 +361,7 @@ class LinkTrafficLevels(AbstractObservationComponent): self.structure = self.generate_structure() - def update(self): + def update(self) -> None: """ Update the observation based on current environment state. 
@@ -385,7 +387,7 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs - def generate_structure(self): + def generate_structure(self) -> List[str]: """Return a list of labels for the components of the flattened observation space.""" structure = [] for _, link in self.env.links.items(): @@ -415,7 +417,7 @@ class ObservationsHandler: "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, } - def __init__(self): + def __init__(self) -> None: """Initialise the observation handler.""" self.registered_obs_components: List[AbstractObservationComponent] = [] @@ -430,7 +432,7 @@ class ObservationsHandler: self.flatten: bool = False - def update_obs(self): + def update_obs(self) -> None: """Fetch fresh information about the environment.""" current_obs = [] for obs in self.registered_obs_components: @@ -443,7 +445,7 @@ class ObservationsHandler: self._observation = tuple(current_obs) self._flat_observation = spaces.flatten(self._space, self._observation) - def register(self, obs_component: AbstractObservationComponent): + def register(self, obs_component: AbstractObservationComponent) -> None: """ Add a component for this handler to track. @@ -453,7 +455,7 @@ class ObservationsHandler: self.registered_obs_components.append(obs_component) self.update_space() - def deregister(self, obs_component: AbstractObservationComponent): + def deregister(self, obs_component: AbstractObservationComponent) -> None: """ Remove a component from this handler. 
@@ -464,7 +466,7 @@ class ObservationsHandler: self.registered_obs_components.remove(obs_component) self.update_space() - def update_space(self): + def update_space(self) -> None: """Rebuild the handler's composite observation space from its components.""" component_spaces = [] for obs_comp in self.registered_obs_components: @@ -481,7 +483,7 @@ class ObservationsHandler: self._flat_space = spaces.Box(0, 1, (0,)) @property - def space(self): + def space(self) -> spaces.Space: """Observation space, return the flattened version if flatten is True.""" if self.flatten: return self._flat_space @@ -489,7 +491,7 @@ class ObservationsHandler: return self._space @property - def current_observation(self): + def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray, ...]]: """Current observation, return the flattened version if flatten is True.""" if self.flatten: return self._flat_observation @@ -497,7 +499,7 @@ class ObservationsHandler: return self._observation @classmethod - def from_config(cls, env: "Primaite", obs_space_config: dict): + def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler": """ Parse a config dictinary, return a new observation handler populated with new observation component objects. @@ -543,7 +545,7 @@ class ObservationsHandler: handler.update_obs() return handler - def describe_structure(self): + def describe_structure(self) -> List[str]: """ Create a list of names for the features of the obs space. 
diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index b92c434e..5bf843f1 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,7 +5,7 @@ import logging import uuid as uuid from pathlib import Path from random import choice, randint, sample, uniform -from typing import Dict, Final, Tuple, Union +from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union import networkx as nx import numpy as np @@ -20,6 +20,7 @@ from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, AgentFramework, + AgentIdentifier, FileSystemState, HardwareState, NodePOLInitiator, @@ -48,7 +49,10 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod from primaite.transactions.transaction import Transaction from primaite.utils.session_output_writer import SessionOutputWriter -_LOGGER = getLogger(__name__) +if TYPE_CHECKING: + from logging import Logger + +_LOGGER: "Logger" = getLogger(__name__) class Primaite(Env): @@ -66,7 +70,7 @@ class Primaite(Env): lay_down_config_path: Union[str, Path], session_path: Path, timestamp_str: str, - ): + ) -> None: """ The Primaite constructor. 
@@ -77,13 +81,14 @@ class Primaite(Env): """ self.session_path: Final[Path] = session_path self.timestamp_str: Final[str] = timestamp_str - self._training_config_path = training_config_path - self._lay_down_config_path = lay_down_config_path + self._training_config_path: Union[str, Path] = training_config_path + self._lay_down_config_path: Union[str, Path] = lay_down_config_path self.training_config: TrainingConfig = training_config.load(training_config_path) _LOGGER.info(f"Using: {str(self.training_config)}") # Number of steps in an episode + self.episode_steps: int if self.training_config.session_type == SessionType.TRAIN: self.episode_steps = self.training_config.num_train_steps elif self.training_config.session_type == SessionType.EVAL: @@ -94,7 +99,7 @@ class Primaite(Env): super(Primaite, self).__init__() # The agent in use - self.agent_identifier = self.training_config.agent_identifier + self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier # Create a dictionary to hold all the nodes self.nodes: Dict[str, NodeUnion] = {} @@ -113,36 +118,38 @@ class Primaite(Env): self.green_iers_reference: Dict[str, IER] = {} # Create a dictionary to hold all the node PoLs (this will come from an external source) + # TODO: figure out type self.node_pol = {} # Create a dictionary to hold all the red agent IERs (this will come from an external source) - self.red_iers = {} + self.red_iers: Dict[str, IER] = {} # Create a dictionary to hold all the red agent node PoLs (this will come from an external source) - self.red_node_pol = {} + self.red_node_pol: Dict[str, NodeStateInstructionRed] = {} # Create the Access Control List - self.acl = AccessControlList() + self.acl: AccessControlList = AccessControlList() # Create a list of services (enums) - self.services_list = [] + self.services_list: List[str] = [] # Create a list of ports - self.ports_list = [] + self.ports_list: List[str] = [] # Create graph (network) - self.network = nx.MultiGraph() + 
self.network: nx.Graph = nx.MultiGraph() # Create a graph (network) reference - self.network_reference = nx.MultiGraph() + self.network_reference: nx.Graph = nx.MultiGraph() # Create step count - self.step_count = 0 + self.step_count: int = 0 self.total_step_count: int = 0 """The total number of time steps completed.""" # Create step info dictionary + # TODO: figure out type self.step_info = {} # Total reward @@ -152,22 +159,23 @@ class Primaite(Env): self.average_reward: float = 0 # Episode count - self.episode_count = 0 + self.episode_count: int = 0 # Number of nodes - gets a value by examining the nodes dictionary after it's been populated - self.num_nodes = 0 + self.num_nodes: int = 0 # Number of links - gets a value by examining the links dictionary after it's been populated - self.num_links = 0 + self.num_links: int = 0 # Number of services - gets a value when config is loaded - self.num_services = 0 + self.num_services: int = 0 # Number of ports - gets a value when config is loaded - self.num_ports = 0 + self.num_ports: int = 0 # The action type - self.action_type = 0 + # TODO: confirm type + self.action_type: int = 0 # TODO fix up with TrainingConfig # stores the observation config from the yaml, default is NODE_LINK_TABLE @@ -179,7 +187,7 @@ class Primaite(Env): # It will be initialised later. 
self.obs_handler: ObservationsHandler - self._obs_space_description = None + self._obs_space_description: Union[List[str], None] = None "The env observation space description for transactions writing" # Open the config file and build the environment laydown @@ -211,9 +219,13 @@ class Primaite(Env): _LOGGER.error("Could not save network diagram", exc_info=True) # Initiate observation space + self.observation_space: spaces.Space + self.env_obs: np.ndarray self.observation_space, self.env_obs = self.init_observations() # Define Action Space - depends on action space type (Node or ACL) + self.action_dict: Dict[int, List[int]] + self.action_space: spaces.Space if self.training_config.action_type == ActionType.NODE: _LOGGER.debug("Action space type NODE selected") # Terms (for node action space): @@ -241,8 +253,12 @@ class Primaite(Env): else: _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}") - self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True) - self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True) + self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter( + self, transaction_writer=False, learning_session=True + ) + self.transaction_writer: SessionOutputWriter = SessionOutputWriter( + self, transaction_writer=True, learning_session=True + ) @property def actual_episode_count(self) -> int: @@ -251,7 +267,7 @@ class Primaite(Env): return self.episode_count - 1 return self.episode_count - def set_as_eval(self): + def set_as_eval(self) -> None: """Set the writers to write to eval directories.""" self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False) self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False) @@ -260,12 +276,12 @@ class Primaite(Env): self.total_step_count = 0 self.episode_steps = self.training_config.num_eval_steps - def 
_write_av_reward_per_episode(self): + def _write_av_reward_per_episode(self) -> None: if self.actual_episode_count > 0: csv_data = self.actual_episode_count, self.average_reward self.episode_av_reward_writer.write(csv_data) - def reset(self): + def reset(self) -> np.ndarray: """ AI Gym Reset function. @@ -299,7 +315,7 @@ class Primaite(Env): return self.env_obs - def step(self, action): + def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]: """ AI Gym Step function. @@ -418,7 +434,7 @@ class Primaite(Env): # Return return self.env_obs, reward, done, self.step_info - def close(self): + def close(self) -> None: """Override parent close and close writers.""" # Close files if last episode/step # if self.can_finish: @@ -427,18 +443,18 @@ class Primaite(Env): self.transaction_writer.close() self.episode_av_reward_writer.close() - def init_acl(self): + def init_acl(self) -> None: """Initialise the Access Control List.""" self.acl.remove_all_rules() - def output_link_status(self): + def output_link_status(self) -> None: """Output the link status of all links to the console.""" for link_key, link_value in self.links.items(): _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) - def interpret_action_and_apply(self, _action): + def interpret_action_and_apply(self, _action: int) -> None: """ Applies agent actions to the nodes and Access Control List. @@ -458,7 +474,7 @@ class Primaite(Env): else: logging.error("Invalid action type found") - def apply_actions_to_nodes(self, _action): + def apply_actions_to_nodes(self, _action: int) -> None: """ Applies agent actions to the nodes.