Continue Adding Typehints

This commit is contained in:
Marek Wolan
2023-07-13 12:25:54 +01:00
parent d2bac4307a
commit 4e4166d4d4
13 changed files with 185 additions and 141 deletions

View File

@@ -5,7 +5,7 @@ import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Union
from typing import Any, Dict, Final, TYPE_CHECKING, Union
from uuid import uuid4
import yaml
@@ -17,7 +17,13 @@ from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
import numpy as np
_LOGGER: "Logger" = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
@@ -47,7 +53,7 @@ class AgentSessionABC(ABC):
"""
@abstractmethod
def __init__(self, training_config_path, lay_down_config_path):
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
"""
Initialise an agent session from config files.
@@ -107,11 +113,11 @@ class AgentSessionABC(ABC):
return path
@property
def uuid(self):
def uuid(self) -> str:
"""The Agent Session UUID."""
return self._uuid
def _write_session_metadata_file(self):
def _write_session_metadata_file(self) -> None:
"""
Write the ``session_metadata.json`` file.
@@ -147,7 +153,7 @@ class AgentSessionABC(ABC):
json.dump(metadata_dict, file)
_LOGGER.debug("Finished writing session metadata file")
def _update_session_metadata_file(self):
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
@@ -176,7 +182,7 @@ class AgentSessionABC(ABC):
_LOGGER.debug("Finished updating session metadata file")
@abstractmethod
def _setup(self):
def _setup(self) -> None:
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
@@ -186,14 +192,14 @@ class AgentSessionABC(ABC):
self._can_evaluate = False
@abstractmethod
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
@abstractmethod
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -210,8 +216,8 @@ class AgentSessionABC(ABC):
@abstractmethod
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -224,7 +230,7 @@ class AgentSessionABC(ABC):
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
@classmethod
@@ -264,7 +270,6 @@ class AgentSessionABC(ABC):
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
@@ -276,21 +281,21 @@ class AgentSessionABC(ABC):
return self.learning_path / file_name
@abstractmethod
def save(self):
def save(self) -> None:
"""Save the agent."""
pass
@abstractmethod
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
pass
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._env.episode_av_reward_writer.close() # noqa
self._env.transaction_writer.close() # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True):
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
# self.close()
title = f"PrimAITE Session {self.timestamp_str} "
subtitle = str(self._training_config)
@@ -318,7 +323,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
"""
Initialise a hardcoded agent session.
@@ -331,7 +336,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self):
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -342,16 +347,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -360,13 +365,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -398,14 +403,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
super().evaluate()
@classmethod
def load(cls):
def load(cls) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Dict, List, Union
import numpy as np
@@ -32,7 +32,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[Any, Any]:
) -> Dict[str, IER]:
"""Get blocked green IERs.
:param green_iers: Green IERs to check for being
@@ -60,7 +60,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
def get_matching_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get list of ACL rules which are relevant to an IER.
:param ier: Information Exchange Request to query against the ACL list
@@ -83,7 +85,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
) -> Dict[int, ACLRule]:
"""
Get blocking ACL rules for an IER.
@@ -111,7 +113,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
) -> Dict[int, ACLRule]:
"""Get all allowing ACL rules for an IER.
:param ier: Information Exchange Request to query against the ACL list
@@ -141,7 +143,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
:param source_node_id: Source node
@@ -186,7 +188,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""List ALLOW rules relating to specified nodes.
:param source_node_id: Source node id
@@ -233,7 +235,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
) -> Dict[int, ACLRule]:
"""List DENY rules relating to specified nodes.
:param source_node_id: Source node id

View File

@@ -4,7 +4,7 @@ import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from typing import Any, Callable, Dict, TYPE_CHECKING, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -18,10 +18,14 @@ from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def _env_creator(env_config):
# TODO: verify type of env_config
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
@@ -30,11 +34,12 @@ def _env_creator(env_config):
)
def _custom_log_creator(session_path: Path):
# TODO: verify type hint return type
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config):
def logger_creator(config: Dict) -> UnifiedLogger:
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
@@ -43,7 +48,7 @@ def _custom_log_creator(session_path: Path):
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path, lay_down_config_path):
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
"""
Initialise the RLLib Agent training session.
@@ -82,7 +87,7 @@ class RLlibAgent(AgentSessionABC):
f"{self._training_config.deep_learning_framework}"
)
def _update_session_metadata_file(self):
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
@@ -110,7 +115,7 @@ class RLlibAgent(AgentSessionABC):
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self):
def _setup(self) -> None:
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
@@ -147,8 +152,8 @@ class RLlibAgent(AgentSessionABC):
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -168,8 +173,8 @@ class RLlibAgent(AgentSessionABC):
def evaluate(
self,
**kwargs,
):
**kwargs: None,
) -> None:
"""
Evaluate the agent.
@@ -177,7 +182,7 @@ class RLlibAgent(AgentSessionABC):
"""
raise NotImplementedError
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError
@classmethod
@@ -185,7 +190,7 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self, overwrite_existing: bool = True):
def save(self, overwrite_existing: bool = True) -> None:
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
@@ -205,6 +210,6 @@ class RLlibAgent(AgentSessionABC):
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Union
from typing import Any, TYPE_CHECKING, Union
import numpy as np
from stable_baselines3 import A2C, PPO
@@ -12,13 +12,16 @@ from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path, lay_down_config_path):
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
"""
Initialise the SB3 Agent training session.
@@ -57,7 +60,7 @@ class SB3Agent(AgentSessionABC):
self.is_eval = False
def _setup(self):
def _setup(self) -> None:
super()._setup()
self._env = Primaite(
training_config_path=self._training_config_path,
@@ -75,7 +78,7 @@ class SB3Agent(AgentSessionABC):
seed=self._training_config.seed,
)
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
save_checkpoint = False
@@ -86,13 +89,13 @@ class SB3Agent(AgentSessionABC):
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -115,8 +118,8 @@ class SB3Agent(AgentSessionABC):
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -150,10 +153,10 @@ class SB3Agent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self):
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
if TYPE_CHECKING:
import numpy as np
class RandomAgent(HardCodedAgentSessionABC):
"""
@@ -9,7 +14,7 @@ class RandomAgent(HardCodedAgentSessionABC):
Get a completely random action from the action space.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: "np.ndarray") -> int:
return self._env.action_space.sample()
@@ -20,7 +25,7 @@ class DummyAgent(HardCodedAgentSessionABC):
All action spaces setup so dummy action is always 0 regardless of action type used.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: "np.ndarray") -> int:
return 0
@@ -31,7 +36,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: "np.ndarray") -> int:
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
@@ -46,7 +51,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs):
def _calculate_action(self, obs: "np.ndarray") -> int:
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)

View File

@@ -38,7 +38,7 @@ def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
return new_action
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.

View File

@@ -1,8 +1,8 @@
from typing import Type, Union
from typing import TypeVar
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
NodeUnion: Type = Union[ActiveNode, PassiveNode, ServiceNode]
NodeUnion = TypeVar("NodeUnion", ServiceNode, ActiveNode, PassiveNode)
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""

View File

@@ -5,17 +5,17 @@
class Protocol(object):
"""Protocol class."""
def __init__(self, _name):
def __init__(self, _name: str) -> None:
"""
Initialise a protocol.
:param _name: The name of the protocol
:type _name: str
"""
self.name = _name
self.load = 0 # bps
self.name: str = _name
self.load: int = 0 # bps
def get_name(self):
def get_name(self) -> str:
"""
Gets the protocol name.
@@ -24,7 +24,7 @@ class Protocol(object):
"""
return self.name
def get_load(self):
def get_load(self) -> int:
"""
Gets the protocol load.
@@ -33,7 +33,7 @@ class Protocol(object):
"""
return self.load
def add_load(self, _load):
def add_load(self, _load: int) -> None:
"""
Adds load to the protocol.
@@ -42,6 +42,6 @@ class Protocol(object):
"""
self.load += _load
def clear_load(self):
def clear_load(self) -> None:
"""Clears the load on this protocol."""
self.load = 0

View File

@@ -15,12 +15,12 @@ class Service(object):
:param port: The service port.
:param software_state: The service SoftwareState.
"""
self.name = name
self.port = port
self.software_state = software_state
self.patching_count = 0
self.name: str = name
self.port: str = port
self.software_state: SoftwareState = software_state
self.patching_count: int = 0
def reduce_patching_count(self):
def reduce_patching_count(self) -> None:
"""Reduces the patching count for the service."""
self.patching_count -= 1
if self.patching_count <= 0:

View File

@@ -1,12 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Any, Dict, Final, Union
from typing import Any, Dict, Final, TYPE_CHECKING, Union
import yaml
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Union
import yaml
@@ -18,7 +18,10 @@ from primaite.common.enums import (
SessionType,
)
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"

View File

@@ -14,17 +14,19 @@ from primaite.nodes.service_node import ServiceNode
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from logging import Logger
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
_LOGGER: "Logger" = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise observation component.
@@ -39,7 +41,7 @@ class AbstractObservationComponent(ABC):
return NotImplemented
@abstractmethod
def update(self):
def update(self) -> None:
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@@ -74,7 +76,7 @@ class NodeLinkTable(AbstractObservationComponent):
_MAX_VAL: int = 1_000_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeLinkTable observation space component.
@@ -101,7 +103,7 @@ class NodeLinkTable(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -148,7 +150,7 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_index += 1
item_index += 1
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
@@ -211,7 +213,7 @@ class NodeStatuses(AbstractObservationComponent):
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeStatuses observation component.
@@ -237,7 +239,7 @@ class NodeStatuses(AbstractObservationComponent):
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -268,7 +270,7 @@ class NodeStatuses(AbstractObservationComponent):
)
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
@@ -317,7 +319,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
) -> None:
"""
Initialise a LinkTrafficLevels observation component.
@@ -359,7 +361,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
def update(self) -> None:
"""
Update the observation based on current environment state.
@@ -385,7 +387,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
@@ -415,7 +417,7 @@ class ObservationsHandler:
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
}
def __init__(self):
def __init__(self) -> None:
"""Initialise the observation handler."""
self.registered_obs_components: List[AbstractObservationComponent] = []
@@ -430,7 +432,7 @@ class ObservationsHandler:
self.flatten: bool = False
def update_obs(self):
def update_obs(self) -> None:
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
@@ -443,7 +445,7 @@ class ObservationsHandler:
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
def register(self, obs_component: AbstractObservationComponent) -> None:
"""
Add a component for this handler to track.
@@ -453,7 +455,7 @@ class ObservationsHandler:
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
def deregister(self, obs_component: AbstractObservationComponent) -> None:
"""
Remove a component from this handler.
@@ -464,7 +466,7 @@ class ObservationsHandler:
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
def update_space(self) -> None:
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
@@ -481,7 +483,7 @@ class ObservationsHandler:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self):
def space(self) -> spaces.Space:
"""Observation space, return the flattened version if flatten is True."""
if self.flatten:
return self._flat_space
@@ -489,7 +491,7 @@ class ObservationsHandler:
return self._space
@property
def current_observation(self):
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""Current observation, return the flattened version if flatten is True."""
if self.flatten:
return self._flat_observation
@@ -497,7 +499,7 @@ class ObservationsHandler:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
"""
Parse a config dictinary, return a new observation handler populated with new observation component objects.
@@ -543,7 +545,7 @@ class ObservationsHandler:
handler.update_obs()
return handler
def describe_structure(self):
def describe_structure(self) -> List[str]:
"""
Create a list of names for the features of the obs space.

View File

@@ -5,7 +5,7 @@ import logging
import uuid as uuid
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Dict, Final, Tuple, Union
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import networkx as nx
import numpy as np
@@ -20,6 +20,7 @@ from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
FileSystemState,
HardwareState,
NodePOLInitiator,
@@ -48,7 +49,10 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
class Primaite(Env):
@@ -66,7 +70,7 @@ class Primaite(Env):
lay_down_config_path: Union[str, Path],
session_path: Path,
timestamp_str: str,
):
) -> None:
"""
The Primaite constructor.
@@ -77,13 +81,14 @@ class Primaite(Env):
"""
self.session_path: Final[Path] = session_path
self.timestamp_str: Final[str] = timestamp_str
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self._training_config_path: Union[str, Path] = training_config_path
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
self.training_config: TrainingConfig = training_config.load(training_config_path)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps: int
if self.training_config.session_type == SessionType.TRAIN:
self.episode_steps = self.training_config.num_train_steps
elif self.training_config.session_type == SessionType.EVAL:
@@ -94,7 +99,7 @@ class Primaite(Env):
super(Primaite, self).__init__()
# The agent in use
self.agent_identifier = self.training_config.agent_identifier
self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier
# Create a dictionary to hold all the nodes
self.nodes: Dict[str, NodeUnion] = {}
@@ -113,36 +118,38 @@ class Primaite(Env):
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
# TODO: figure out type
self.node_pol = {}
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
self.red_iers = {}
self.red_iers: Dict[str, IER] = {}
# Create a dictionary to hold all the red agent node PoLs (this will come from an external source)
self.red_node_pol = {}
self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}
# Create the Access Control List
self.acl = AccessControlList()
self.acl: AccessControlList = AccessControlList()
# Create a list of services (enums)
self.services_list = []
self.services_list: List[str] = []
# Create a list of ports
self.ports_list = []
self.ports_list: List[str] = []
# Create graph (network)
self.network = nx.MultiGraph()
self.network: nx.Graph = nx.MultiGraph()
# Create a graph (network) reference
self.network_reference = nx.MultiGraph()
self.network_reference: nx.Graph = nx.MultiGraph()
# Create step count
self.step_count = 0
self.step_count: int = 0
self.total_step_count: int = 0
"""The total number of time steps completed."""
# Create step info dictionary
# TODO: figure out type
self.step_info = {}
# Total reward
@@ -152,22 +159,23 @@ class Primaite(Env):
self.average_reward: float = 0
# Episode count
self.episode_count = 0
self.episode_count: int = 0
# Number of nodes - gets a value by examining the nodes dictionary after it's been populated
self.num_nodes = 0
self.num_nodes: int = 0
# Number of links - gets a value by examining the links dictionary after it's been populated
self.num_links = 0
self.num_links: int = 0
# Number of services - gets a value when config is loaded
self.num_services = 0
self.num_services: int = 0
# Number of ports - gets a value when config is loaded
self.num_ports = 0
self.num_ports: int = 0
# The action type
self.action_type = 0
# TODO: confirm type
self.action_type: int = 0
# TODO fix up with TrainingConfig
# stores the observation config from the yaml, default is NODE_LINK_TABLE
@@ -179,7 +187,7 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
self._obs_space_description = None
self._obs_space_description: List[str] = None
"The env observation space description for transactions writing"
# Open the config file and build the environment laydown
@@ -211,9 +219,13 @@ class Primaite(Env):
_LOGGER.error("Could not save network diagram", exc_info=True)
# Initiate observation space
self.observation_space: spaces.Space
self.env_obs: np.ndarray
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
self.action_dict: Dict[int, List[int]]
self.action_space: spaces.Space
if self.training_config.action_type == ActionType.NODE:
_LOGGER.debug("Action space type NODE selected")
# Terms (for node action space):
@@ -241,8 +253,12 @@ class Primaite(Env):
else:
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
self, transaction_writer=False, learning_session=True
)
self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
self, transaction_writer=True, learning_session=True
)
@property
def actual_episode_count(self) -> int:
@@ -251,7 +267,7 @@ class Primaite(Env):
return self.episode_count - 1
return self.episode_count
def set_as_eval(self):
def set_as_eval(self) -> None:
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
@@ -260,12 +276,12 @@ class Primaite(Env):
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
def _write_av_reward_per_episode(self):
def _write_av_reward_per_episode(self) -> None:
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
def reset(self):
def reset(self) -> np.ndarray:
"""
AI Gym Reset function.
@@ -299,7 +315,7 @@ class Primaite(Env):
return self.env_obs
def step(self, action):
def step(self, action: int) -> tuple(np.ndarray, float, bool, Dict):
"""
AI Gym Step function.
@@ -418,7 +434,7 @@ class Primaite(Env):
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
def close(self) -> None:
"""Override parent close and close writers."""
# Close files if last episode/step
# if self.can_finish:
@@ -427,18 +443,18 @@ class Primaite(Env):
self.transaction_writer.close()
self.episode_av_reward_writer.close()
def init_acl(self):
def init_acl(self) -> None:
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
def output_link_status(self):
def output_link_status(self) -> None:
"""Output the link status of all links to the console."""
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
def interpret_action_and_apply(self, _action: int) -> None:
"""
Applies agent actions to the nodes and Access Control List.
@@ -458,7 +474,7 @@ class Primaite(Env):
else:
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
def apply_actions_to_nodes(self, _action: int) -> None:
"""
Applies agent actions to the nodes.