Continue Adding Typehints
This commit is contained in:
@@ -5,7 +5,7 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Any, Dict, Final, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
@@ -17,7 +17,13 @@ from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
@@ -47,7 +53,7 @@ class AgentSessionABC(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files.
|
||||
|
||||
@@ -107,11 +113,11 @@ class AgentSessionABC(ABC):
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
"""The Agent Session UUID."""
|
||||
return self._uuid
|
||||
|
||||
def _write_session_metadata_file(self):
|
||||
def _write_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Write the ``session_metadata.json`` file.
|
||||
|
||||
@@ -147,7 +153,7 @@ class AgentSessionABC(ABC):
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished writing session metadata file")
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
@@ -176,7 +182,7 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
@@ -186,14 +192,14 @@ class AgentSessionABC(ABC):
|
||||
self._can_evaluate = False
|
||||
|
||||
@abstractmethod
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -210,8 +216,8 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -224,7 +230,7 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@@ -264,7 +270,6 @@ class AgentSessionABC(ABC):
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
@@ -276,21 +281,21 @@ class AgentSessionABC(ABC):
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._env.episode_av_reward_writer.close() # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True):
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
@@ -318,7 +323,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
@@ -331,7 +336,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -342,16 +347,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -360,13 +365,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -398,14 +403,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
def load(cls) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -32,7 +32,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[Any, Any]:
|
||||
) -> Dict[str, IER]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
@@ -60,7 +60,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
|
||||
def get_matching_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
@@ -83,7 +85,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
@@ -111,7 +113,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_allow_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
@@ -141,7 +143,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Filter ACL rules to only those which are relevant to the specified nodes.
|
||||
|
||||
:param source_node_id: Source node
|
||||
@@ -186,7 +188,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List ALLOW rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
@@ -233,7 +235,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List DENY rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Any, Callable, Dict, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -18,10 +18,14 @@ from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def _env_creator(env_config):
|
||||
# TODO: verify type of env_config
|
||||
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
@@ -30,11 +34,12 @@ def _env_creator(env_config):
|
||||
)
|
||||
|
||||
|
||||
def _custom_log_creator(session_path: Path):
|
||||
# TODO: verify type hint return type
|
||||
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def logger_creator(config):
|
||||
def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
return logger_creator
|
||||
@@ -43,7 +48,7 @@ def _custom_log_creator(session_path: Path):
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
@@ -82,7 +87,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
@@ -110,7 +115,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._agent_config = self._agent_config_class()
|
||||
@@ -147,8 +152,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -168,8 +173,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: None,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -177,7 +182,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -185,7 +190,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, overwrite_existing: bool = True):
|
||||
def save(self, overwrite_existing: bool = True) -> None:
|
||||
"""Save the agent."""
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
@@ -205,6 +210,6 @@ class RLlibAgent(AgentSessionABC):
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -12,13 +12,16 @@ from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
@@ -57,7 +60,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
@@ -75,7 +78,7 @@ class SB3Agent(AgentSessionABC):
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
save_checkpoint = False
|
||||
@@ -86,13 +89,13 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -115,8 +118,8 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -150,10 +153,10 @@ class SB3Agent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
@@ -9,7 +14,7 @@ class RandomAgent(HardCodedAgentSessionABC):
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
return self._env.action_space.sample()
|
||||
|
||||
|
||||
@@ -20,7 +25,7 @@ class DummyAgent(HardCodedAgentSessionABC):
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
return 0
|
||||
|
||||
|
||||
@@ -31,7 +36,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
@@ -46,7 +51,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
nothing_action = [1, "NONE", "ON", 0]
|
||||
nothing_action = transform_action_node_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
|
||||
@@ -38,7 +38,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
|
||||
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Type, Union
|
||||
from typing import TypeVar
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
NodeUnion = TypeVar("NodeUnion", ServiceNode, ActiveNode, PassiveNode)
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name):
|
||||
def __init__(self, _name: str) -> None:
|
||||
"""
|
||||
Initialise a protocol.
|
||||
|
||||
:param _name: The name of the protocol
|
||||
:type _name: str
|
||||
"""
|
||||
self.name = _name
|
||||
self.load = 0 # bps
|
||||
self.name: str = _name
|
||||
self.load: int = 0 # bps
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Gets the protocol name.
|
||||
|
||||
@@ -24,7 +24,7 @@ class Protocol(object):
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets the protocol load.
|
||||
|
||||
@@ -33,7 +33,7 @@ class Protocol(object):
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load):
|
||||
def add_load(self, _load: int) -> None:
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
|
||||
@@ -42,6 +42,6 @@ class Protocol(object):
|
||||
"""
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self):
|
||||
def clear_load(self) -> None:
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
|
||||
@@ -15,12 +15,12 @@ class Service(object):
|
||||
:param port: The service port.
|
||||
:param software_state: The service SoftwareState.
|
||||
"""
|
||||
self.name = name
|
||||
self.port = port
|
||||
self.software_state = software_state
|
||||
self.patching_count = 0
|
||||
self.name: str = name
|
||||
self.port: str = port
|
||||
self.software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
|
||||
def reduce_patching_count(self):
|
||||
def reduce_patching_count(self) -> None:
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union
|
||||
from typing import Any, Dict, Final, TYPE_CHECKING, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -18,7 +18,10 @@ from primaite.common.enums import (
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
|
||||
@@ -14,17 +14,19 @@ from primaite.nodes.service_node import ServiceNode
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER: "Logger" = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise observation component.
|
||||
|
||||
@@ -39,7 +41,7 @@ class AbstractObservationComponent(ABC):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@@ -74,7 +76,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
_MAX_VAL: int = 1_000_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeLinkTable observation space component.
|
||||
|
||||
@@ -101,7 +103,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -148,7 +150,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
nodes = self.env.nodes.values()
|
||||
links = self.env.links.values()
|
||||
@@ -211,7 +213,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeStatuses observation component.
|
||||
|
||||
@@ -237,7 +239,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -268,7 +270,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
services = self.env.services_list
|
||||
|
||||
@@ -317,7 +319,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a LinkTrafficLevels observation component.
|
||||
|
||||
@@ -359,7 +361,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
@@ -385,7 +387,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for _, link in self.env.links.items():
|
||||
@@ -415,7 +417,7 @@ class ObservationsHandler:
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the observation handler."""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
@@ -430,7 +432,7 @@ class ObservationsHandler:
|
||||
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self):
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
@@ -443,7 +445,7 @@ class ObservationsHandler:
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
def register(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Add a component for this handler to track.
|
||||
|
||||
@@ -453,7 +455,7 @@ class ObservationsHandler:
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent):
|
||||
def deregister(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Remove a component from this handler.
|
||||
|
||||
@@ -464,7 +466,7 @@ class ObservationsHandler:
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self):
|
||||
def update_space(self) -> None:
|
||||
"""Rebuild the handler's composite observation space from its components."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
@@ -481,7 +483,7 @@ class ObservationsHandler:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
|
||||
def space(self):
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
return self._flat_space
|
||||
@@ -489,7 +491,7 @@ class ObservationsHandler:
|
||||
return self._space
|
||||
|
||||
@property
|
||||
def current_observation(self):
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
return self._flat_observation
|
||||
@@ -497,7 +499,7 @@ class ObservationsHandler:
|
||||
return self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
|
||||
"""
|
||||
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
@@ -543,7 +545,7 @@ class ObservationsHandler:
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
|
||||
def describe_structure(self) -> List[str]:
|
||||
"""
|
||||
Create a list of names for the features of the obs space.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import uuid as uuid
|
||||
from pathlib import Path
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Dict, Final, Tuple, Union
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -20,6 +20,7 @@ from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
@@ -48,7 +49,10 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
@@ -66,7 +70,7 @@ class Primaite(Env):
|
||||
lay_down_config_path: Union[str, Path],
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The Primaite constructor.
|
||||
|
||||
@@ -77,13 +81,14 @@ class Primaite(Env):
|
||||
"""
|
||||
self.session_path: Final[Path] = session_path
|
||||
self.timestamp_str: Final[str] = timestamp_str
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
self._training_config_path: Union[str, Path] = training_config_path
|
||||
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
|
||||
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
self.episode_steps: int
|
||||
if self.training_config.session_type == SessionType.TRAIN:
|
||||
self.episode_steps = self.training_config.num_train_steps
|
||||
elif self.training_config.session_type == SessionType.EVAL:
|
||||
@@ -94,7 +99,7 @@ class Primaite(Env):
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.training_config.agent_identifier
|
||||
self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier
|
||||
|
||||
# Create a dictionary to hold all the nodes
|
||||
self.nodes: Dict[str, NodeUnion] = {}
|
||||
@@ -113,36 +118,38 @@ class Primaite(Env):
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
# TODO: figure out type
|
||||
self.node_pol = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
|
||||
self.red_iers = {}
|
||||
self.red_iers: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent node PoLs (this will come from an external source)
|
||||
self.red_node_pol = {}
|
||||
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
|
||||
|
||||
# Create the Access Control List
|
||||
self.acl = AccessControlList()
|
||||
self.acl: AccessControlList = AccessControlList()
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
self.services_list: List[str] = []
|
||||
|
||||
# Create a list of ports
|
||||
self.ports_list = []
|
||||
self.ports_list: List[str] = []
|
||||
|
||||
# Create graph (network)
|
||||
self.network = nx.MultiGraph()
|
||||
self.network: nx.Graph = nx.MultiGraph()
|
||||
|
||||
# Create a graph (network) reference
|
||||
self.network_reference = nx.MultiGraph()
|
||||
self.network_reference: nx.Graph = nx.MultiGraph()
|
||||
|
||||
# Create step count
|
||||
self.step_count = 0
|
||||
self.step_count: int = 0
|
||||
|
||||
self.total_step_count: int = 0
|
||||
"""The total number of time steps completed."""
|
||||
|
||||
# Create step info dictionary
|
||||
# TODO: figure out type
|
||||
self.step_info = {}
|
||||
|
||||
# Total reward
|
||||
@@ -152,22 +159,23 @@ class Primaite(Env):
|
||||
self.average_reward: float = 0
|
||||
|
||||
# Episode count
|
||||
self.episode_count = 0
|
||||
self.episode_count: int = 0
|
||||
|
||||
# Number of nodes - gets a value by examining the nodes dictionary after it's been populated
|
||||
self.num_nodes = 0
|
||||
self.num_nodes: int = 0
|
||||
|
||||
# Number of links - gets a value by examining the links dictionary after it's been populated
|
||||
self.num_links = 0
|
||||
self.num_links: int = 0
|
||||
|
||||
# Number of services - gets a value when config is loaded
|
||||
self.num_services = 0
|
||||
self.num_services: int = 0
|
||||
|
||||
# Number of ports - gets a value when config is loaded
|
||||
self.num_ports = 0
|
||||
self.num_ports: int = 0
|
||||
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
# TODO: confirm type
|
||||
self.action_type: int = 0
|
||||
|
||||
# TODO fix up with TrainingConfig
|
||||
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||
@@ -179,7 +187,7 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
self._obs_space_description = None
|
||||
self._obs_space_description: List[str] = None
|
||||
"The env observation space description for transactions writing"
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -211,9 +219,13 @@ class Primaite(Env):
|
||||
_LOGGER.error("Could not save network diagram", exc_info=True)
|
||||
|
||||
# Initiate observation space
|
||||
self.observation_space: spaces.Space
|
||||
self.env_obs: np.ndarray
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
self.action_dict: Dict[int, List[int]]
|
||||
self.action_space: spaces.Space
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
_LOGGER.debug("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
@@ -241,8 +253,12 @@ class Primaite(Env):
|
||||
else:
|
||||
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
|
||||
self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=True
|
||||
)
|
||||
self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
|
||||
@property
|
||||
def actual_episode_count(self) -> int:
|
||||
@@ -251,7 +267,7 @@ class Primaite(Env):
|
||||
return self.episode_count - 1
|
||||
return self.episode_count
|
||||
|
||||
def set_as_eval(self):
|
||||
def set_as_eval(self) -> None:
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
|
||||
@@ -260,12 +276,12 @@ class Primaite(Env):
|
||||
self.total_step_count = 0
|
||||
self.episode_steps = self.training_config.num_eval_steps
|
||||
|
||||
def _write_av_reward_per_episode(self):
|
||||
def _write_av_reward_per_episode(self) -> None:
|
||||
if self.actual_episode_count > 0:
|
||||
csv_data = self.actual_episode_count, self.average_reward
|
||||
self.episode_av_reward_writer.write(csv_data)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> np.ndarray:
|
||||
"""
|
||||
AI Gym Reset function.
|
||||
|
||||
@@ -299,7 +315,7 @@ class Primaite(Env):
|
||||
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
def step(self, action: int) -> tuple(np.ndarray, float, bool, Dict):
|
||||
"""
|
||||
AI Gym Step function.
|
||||
|
||||
@@ -418,7 +434,7 @@ class Primaite(Env):
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Override parent close and close writers."""
|
||||
# Close files if last episode/step
|
||||
# if self.can_finish:
|
||||
@@ -427,18 +443,18 @@ class Primaite(Env):
|
||||
self.transaction_writer.close()
|
||||
self.episode_av_reward_writer.close()
|
||||
|
||||
def init_acl(self):
|
||||
def init_acl(self) -> None:
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
def output_link_status(self):
|
||||
def output_link_status(self) -> None:
|
||||
"""Output the link status of all links to the console."""
|
||||
for link_key, link_value in self.links.items():
|
||||
_LOGGER.debug("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.protocol_list:
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
def interpret_action_and_apply(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
@@ -458,7 +474,7 @@ class Primaite(Env):
|
||||
else:
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
def apply_actions_to_nodes(self, _action: int) -> None:
|
||||
"""
|
||||
Applies agent actions to the nodes.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user