Added type hints

This commit is contained in:
Marek Wolan
2023-07-14 12:01:38 +01:00
parent a923d818d3
commit c57ed6edcd
16 changed files with 166 additions and 128 deletions

View File

@@ -19,7 +19,7 @@ app = typer.Typer()
@app.command()
def build_dirs():
def build_dirs() -> None:
"""Build the PrimAITE app directories."""
from primaite.setup import setup_app_dirs
@@ -27,7 +27,7 @@ def build_dirs():
@app.command()
def reset_notebooks(overwrite: bool = True):
def reset_notebooks(overwrite: bool = True) -> None:
"""
Force a reset of the demo notebooks in the users notebooks directory.
@@ -39,7 +39,7 @@ def reset_notebooks(overwrite: bool = True):
@app.command()
def logs(last_n: Annotated[int, typer.Option("-n")]):
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
"""
Print the PrimAITE log file.
@@ -60,7 +60,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
@app.command()
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
"""
View or set the PrimAITE Log Level.
@@ -88,7 +88,7 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
@app.command()
def notebooks():
def notebooks() -> None:
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
@@ -96,7 +96,7 @@ def notebooks():
@app.command()
def version():
def version() -> None:
"""Get the installed PrimAITE version number."""
import primaite
@@ -104,7 +104,7 @@ def version():
@app.command()
def clean_up():
def clean_up() -> None:
"""Cleans up left over files from previous version installations."""
from primaite.setup import old_installation_clean_up
@@ -112,7 +112,7 @@ def clean_up():
@app.command()
def setup(overwrite_existing: bool = True):
def setup(overwrite_existing: bool = True) -> None:
"""
Perform the PrimAITE first-time setup.
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
"""
Run a PrimAITE session.
@@ -177,7 +177,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
"""
View or set the plotly template for Session plots.

View File

@@ -13,7 +13,7 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
) -> None:
"""
Run the PrimAITE Session.

View File

@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from primaite.common.enums import HardwareState, NodePOLType, SoftwareState
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
class NodeStateInstructionGreen(object):
@@ -15,9 +15,9 @@ class NodeStateInstructionGreen(object):
_start_step: int,
_end_step: int,
_node_id: str,
_node_pol_type,
_service_name,
_state,
_node_pol_type: "NodePOLType",
_service_name: str,
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
):
"""
Initialise the Node State Instruction.
@@ -37,9 +37,9 @@ class NodeStateInstructionGreen(object):
self.node_pol_type: "NodePOLType" = _node_pol_type
self.service_name: str = _service_name # Not used when not a service instruction
# TODO: confirm type of state
self.state: Union["HardwareState", "SoftwareState"] = _state
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets the start step.
@@ -48,7 +48,7 @@ class NodeStateInstructionGreen(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets the end step.
@@ -57,7 +57,7 @@ class NodeStateInstructionGreen(object):
"""
return self.end_step
def get_node_id(self):
def get_node_id(self) -> str:
"""
Gets the node ID.
@@ -66,7 +66,7 @@ class NodeStateInstructionGreen(object):
"""
return self.node_id
def get_node_pol_type(self):
def get_node_pol_type(self) -> "NodePOLType":
"""
Gets the node pattern of life type (enum).
@@ -75,7 +75,7 @@ class NodeStateInstructionGreen(object):
"""
return self.node_pol_type
def get_service_name(self):
def get_service_name(self) -> str:
"""
Gets the service name.
@@ -84,7 +84,7 @@ class NodeStateInstructionGreen(object):
"""
return self.service_name
def get_state(self):
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).

View File

@@ -1,9 +1,13 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from primaite.common.enums import NodePOLType
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
@dataclass()
class NodeStateInstructionRed(object):
@@ -11,18 +15,18 @@ class NodeStateInstructionRed(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_target_node_id,
_pol_initiator,
_id: str,
_start_step: int,
_end_step: int,
_target_node_id: str,
_pol_initiator: "NodePOLInitiator",
_pol_type: NodePOLType,
pol_protocol,
_pol_state,
_pol_source_node_id,
_pol_source_node_service,
_pol_source_node_service_state,
):
pol_protocol: str,
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
_pol_source_node_id: str,
_pol_source_node_service: str,
_pol_source_node_service_state: str,
) -> None:
"""
Initialise the Node State Instruction for the red agent.
@@ -38,19 +42,19 @@ class NodeStateInstructionRed(object):
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.target_node_id = _target_node_id
self.initiator = _pol_initiator
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.target_node_id: str = _target_node_id
self.initiator: "NodePOLInitiator" = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name = pol_protocol # Not used when not a service instruction
self.state = _pol_state
self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service
self.service_name: str = pol_protocol # Not used when not a service instruction
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
self.source_node_id: str = _pol_source_node_id
self.source_node_service: str = _pol_source_node_service
self.source_node_service_state = _pol_source_node_service_state
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets the start step.
@@ -59,7 +63,7 @@ class NodeStateInstructionRed(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets the end step.
@@ -68,7 +72,7 @@ class NodeStateInstructionRed(object):
"""
return self.end_step
def get_target_node_id(self):
def get_target_node_id(self) -> str:
"""
Gets the node ID.
@@ -77,7 +81,7 @@ class NodeStateInstructionRed(object):
"""
return self.target_node_id
def get_initiator(self):
def get_initiator(self) -> "NodePOLInitiator":
"""
Gets the initiator.
@@ -95,7 +99,7 @@ class NodeStateInstructionRed(object):
"""
return self.pol_type
def get_service_name(self):
def get_service_name(self) -> str:
"""
Gets the service name.
@@ -104,7 +108,7 @@ class NodeStateInstructionRed(object):
"""
return self.service_name
def get_state(self):
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
@@ -113,7 +117,7 @@ class NodeStateInstructionRed(object):
"""
return self.state
def get_source_node_id(self):
def get_source_node_id(self) -> str:
"""
Gets the source node id (used for initiator type SERVICE).
@@ -122,7 +126,7 @@ class NodeStateInstructionRed(object):
"""
return self.source_node_id
def get_source_node_service(self):
def get_source_node_service(self) -> str:
"""
Gets the source node service (used for initiator type SERVICE).
@@ -131,7 +135,7 @@ class NodeStateInstructionRed(object):
"""
return self.source_node_service
def get_source_node_service_state(self):
def get_source_node_service_state(self) -> str:
"""
Gets the source node service state (used for initiator type SERVICE).

View File

@@ -4,13 +4,17 @@ import importlib.util
import os
import subprocess
import sys
from typing import TYPE_CHECKING
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def start_jupyter_session():
def start_jupyter_session() -> None:
"""
Starts a new Jupyter notebook session in the app notebooks directory.

View File

@@ -14,7 +14,7 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE = False
_VERBOSE: bool = False
def apply_iers(
@@ -24,7 +24,7 @@ def apply_iers(
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
):
) -> None:
"""
Applies IERs to the links (link pattern of life).
@@ -217,7 +217,7 @@ def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
step: int,
):
) -> None:
"""
Applies node pattern of life.

View File

@@ -11,17 +11,17 @@ class IER(object):
def __init__(
self,
_id,
_start_step,
_end_step,
_load,
_protocol,
_port,
_source_node_id,
_dest_node_id,
_mission_criticality,
_running=False,
):
_id: str,
_start_step: int,
_end_step: int,
_load: int,
_protocol: str,
_port: str,
_source_node_id: str,
_dest_node_id: str,
_mission_criticality: int,
_running: bool = False,
) -> None:
"""
Initialise an Information Exchange Request.
@@ -36,18 +36,18 @@ class IER(object):
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.source_node_id = _source_node_id
self.dest_node_id = _dest_node_id
self.load = _load
self.protocol = _protocol
self.port = _port
self.mission_criticality = _mission_criticality
self.running = _running
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.source_node_id: str = _source_node_id
self.dest_node_id: str = _dest_node_id
self.load: int = _load
self.protocol: str = _protocol
self.port: str = _port
self.mission_criticality: int = _mission_criticality
self.running: bool = _running
def get_id(self):
def get_id(self) -> str:
"""
Gets IER ID.
@@ -56,7 +56,7 @@ class IER(object):
"""
return self.id
def get_start_step(self):
def get_start_step(self) -> int:
"""
Gets IER start step.
@@ -65,7 +65,7 @@ class IER(object):
"""
return self.start_step
def get_end_step(self):
def get_end_step(self) -> int:
"""
Gets IER end step.
@@ -74,7 +74,7 @@ class IER(object):
"""
return self.end_step
def get_load(self):
def get_load(self) -> int:
"""
Gets IER load.
@@ -83,7 +83,7 @@ class IER(object):
"""
return self.load
def get_protocol(self):
def get_protocol(self) -> str:
"""
Gets IER protocol.
@@ -92,7 +92,7 @@ class IER(object):
"""
return self.protocol
def get_port(self):
def get_port(self) -> str:
"""
Gets IER port.
@@ -101,7 +101,7 @@ class IER(object):
"""
return self.port
def get_source_node_id(self):
def get_source_node_id(self) -> str:
"""
Gets IER source node ID.
@@ -110,7 +110,7 @@ class IER(object):
"""
return self.source_node_id
def get_dest_node_id(self):
def get_dest_node_id(self) -> str:
"""
Gets IER destination node ID.
@@ -119,7 +119,7 @@ class IER(object):
"""
return self.dest_node_id
def get_is_running(self):
def get_is_running(self) -> bool:
"""
Informs whether the IER is currently running.
@@ -128,7 +128,7 @@ class IER(object):
"""
return self.running
def set_is_running(self, _value):
def set_is_running(self, _value: bool) -> None:
"""
Sets the running state of the IER.
@@ -137,7 +137,7 @@ class IER(object):
"""
self.running = _value
def get_mission_criticality(self):
def get_mission_criticality(self) -> int:
"""
Gets the IER mission criticality (used in the reward function).

View File

@@ -13,7 +13,7 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE = False
_VERBOSE: bool = False
def apply_red_agent_iers(
@@ -23,7 +23,7 @@ def apply_red_agent_iers(
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
):
) -> None:
"""
Applies IERs to the links (link POL) resulting from red agent attack.
@@ -213,7 +213,7 @@ def apply_red_agent_node_pol(
iers: Dict[str, IER],
node_pol: Dict[str, NodeStateInstructionRed],
step: int,
):
) -> None:
"""
Applies node pattern of life.

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Union
from typing import Any, Dict, Final, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
@@ -29,7 +29,7 @@ class PrimaiteSession:
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
) -> None:
"""
The PrimaiteSession constructor.
@@ -52,7 +52,7 @@ class PrimaiteSession:
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self):
def setup(self) -> None:
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
@@ -123,8 +123,8 @@ class PrimaiteSession:
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -135,8 +135,8 @@ class PrimaiteSession:
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -145,6 +145,6 @@ class PrimaiteSession:
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(**kwargs)
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._agent_session.close()

View File

@@ -1,10 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import getLogger
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def run():
def run() -> None:
"""Perform the full clean-up."""
pass

View File

@@ -3,15 +3,19 @@ import filecmp
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def run(overwrite_existing: bool = True):
def run(overwrite_existing: bool = True) -> None:
"""
Resets the demo jupyter notebooks in the users app notebooks directory.

View File

@@ -2,15 +2,19 @@ import filecmp
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def run(overwrite_existing=True):
def run(overwrite_existing: bool = True) -> None:
"""
Resets the example config files in the users app config directory.

View File

@@ -1,10 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def run():
def run() -> None:
"""
Handles creation of application directories and user directories.

View File

@@ -1,15 +1,19 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Tuple
from typing import List, Tuple, TYPE_CHECKING, Union
from primaite.common.enums import AgentIdentifier
if TYPE_CHECKING:
import numpy as np
from gym import spaces
class Transaction(object):
"""Transaction class."""
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
"""
Transaction constructor.
@@ -17,7 +21,7 @@ class Transaction(object):
:param episode_number: The episode number
:param step_number: The step number
"""
self.timestamp = datetime.now()
self.timestamp: datetime = datetime.now()
"The datetime of the transaction"
self.agent_identifier: AgentIdentifier = agent_identifier
"The agent identifier"
@@ -25,17 +29,17 @@ class Transaction(object):
"The episode number"
self.step_number: int = step_number
"The step number"
self.obs_space = None
self.obs_space: "spaces.Space" = None
"The observation space (pre)"
self.obs_space_pre = None
self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None
"The observation space before any actions are taken"
self.obs_space_post = None
self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None
"The observation space after any actions are taken"
self.reward: float = None
"The reward value"
self.action_space = None
self.action_space: int = None
"The action space invoked by the agent"
self.obs_space_description = None
self.obs_space_description: List[str] = None
"The env observation space description"
def as_csv_data(self) -> Tuple[List, List]:
@@ -68,7 +72,7 @@ class Transaction(object):
return header, row
def _turn_action_space_to_array(action_space) -> List[str]:
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
"""
Turns action space into a string array so it can be saved to csv.
@@ -81,7 +85,7 @@ def _turn_action_space_to_array(action_space) -> List[str]:
return [str(action_space)]
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
"""
Turns observation space into a string array so it can be saved to csv.

View File

@@ -1,12 +1,16 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import os
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def get_file_path(path: str) -> Path:

View File

@@ -6,6 +6,9 @@ from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from io import TextIOWrapper
from pathlib import Path
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
@@ -28,7 +31,7 @@ class SessionOutputWriter:
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True,
):
) -> None:
"""
Initialise the Session Output Writer.
@@ -41,15 +44,16 @@ class SessionOutputWriter:
determines the name of the folder which contains the final output csv. Defaults to True
:type learning_session: bool, optional
"""
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
self._env: "Primaite" = env
self.transaction_writer: bool = transaction_writer
self.learning_session: bool = learning_session
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
self._csv_file_path: "Path"
if self.learning_session:
self._csv_file_path = self._env.session_path / "learning" / fn
else:
@@ -57,26 +61,26 @@ class SessionOutputWriter:
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
self._csv_file = None
self._csv_writer = None
self._csv_file: "TextIOWrapper" = None
self._csv_writer: "csv._writer" = None
self._first_write: bool = True
def _init_csv_writer(self):
def _init_csv_writer(self) -> None:
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
def __del__(self) -> None:
self.close()
def close(self):
def close(self) -> None:
"""Close the cvs file."""
if self._csv_file:
self._csv_file.close()
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(self, data: Union[Tuple, Transaction]):
def write(self, data: Union[Tuple, Transaction]) -> None:
"""
Write a row of session data.