Added type hints
This commit is contained in:
@@ -19,7 +19,7 @@ app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def build_dirs():
|
||||
def build_dirs() -> None:
|
||||
"""Build the PrimAITE app directories."""
|
||||
from primaite.setup import setup_app_dirs
|
||||
|
||||
@@ -27,7 +27,7 @@ def build_dirs():
|
||||
|
||||
|
||||
@app.command()
|
||||
def reset_notebooks(overwrite: bool = True):
|
||||
def reset_notebooks(overwrite: bool = True) -> None:
|
||||
"""
|
||||
Force a reset of the demo notebooks in the users notebooks directory.
|
||||
|
||||
@@ -39,7 +39,7 @@ def reset_notebooks(overwrite: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
|
||||
@@ -60,7 +60,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
|
||||
|
||||
|
||||
@app.command()
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the PrimAITE Log Level.
|
||||
|
||||
@@ -88,7 +88,7 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks():
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
@@ -96,7 +96,7 @@ def notebooks():
|
||||
|
||||
|
||||
@app.command()
|
||||
def version():
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
import primaite
|
||||
|
||||
@@ -104,7 +104,7 @@ def version():
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up():
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
@@ -112,7 +112,7 @@ def clean_up():
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True):
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
@@ -177,7 +177,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ _LOGGER = getLogger(__name__)
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import HardwareState, NodePOLType, SoftwareState
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
@@ -15,9 +15,9 @@ class NodeStateInstructionGreen(object):
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_node_id: str,
|
||||
_node_pol_type,
|
||||
_service_name,
|
||||
_state,
|
||||
_node_pol_type: "NodePOLType",
|
||||
_service_name: str,
|
||||
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
):
|
||||
"""
|
||||
Initialise the Node State Instruction.
|
||||
@@ -37,9 +37,9 @@ class NodeStateInstructionGreen(object):
|
||||
self.node_pol_type: "NodePOLType" = _node_pol_type
|
||||
self.service_name: str = _service_name # Not used when not a service instruction
|
||||
# TODO: confirm type of state
|
||||
self.state: Union["HardwareState", "SoftwareState"] = _state
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
@@ -48,7 +48,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
@@ -57,7 +57,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
def get_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
@@ -66,7 +66,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
def get_node_pol_type(self) -> "NodePOLType":
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
@@ -75,7 +75,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
@@ -84,7 +84,7 @@ class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
|
||||
|
||||
|
||||
@dataclass()
|
||||
class NodeStateInstructionRed(object):
|
||||
@@ -11,18 +15,18 @@ class NodeStateInstructionRed(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_target_node_id: str,
|
||||
_pol_initiator: "NodePOLInitiator",
|
||||
_pol_type: NodePOLType,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
):
|
||||
pol_protocol: str,
|
||||
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
_pol_source_node_id: str,
|
||||
_pol_source_node_service: str,
|
||||
_pol_source_node_service_state: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction for the red agent.
|
||||
|
||||
@@ -38,19 +42,19 @@ class NodeStateInstructionRed(object):
|
||||
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.target_node_id: str = _target_node_id
|
||||
self.initiator: "NodePOLInitiator" = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
self.service_name: str = pol_protocol # Not used when not a service instruction
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
|
||||
self.source_node_id: str = _pol_source_node_id
|
||||
self.source_node_service: str = _pol_source_node_service
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
@@ -59,7 +63,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
@@ -68,7 +72,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
def get_target_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
@@ -77,7 +81,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
def get_initiator(self) -> "NodePOLInitiator":
|
||||
"""
|
||||
Gets the initiator.
|
||||
|
||||
@@ -95,7 +99,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
@@ -104,7 +108,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
@@ -113,7 +117,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
@@ -122,7 +126,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
def get_source_node_service(self) -> str:
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
@@ -131,7 +135,7 @@ class NodeStateInstructionRed(object):
|
||||
"""
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
def get_source_node_service_state(self) -> str:
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
|
||||
@@ -4,13 +4,17 @@ import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session():
|
||||
def start_jupyter_session() -> None:
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE = False
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_iers(
|
||||
@@ -24,7 +24,7 @@ def apply_iers(
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
@@ -217,7 +217,7 @@ def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
|
||||
@@ -11,17 +11,17 @@ class IER(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_load,
|
||||
_protocol,
|
||||
_port,
|
||||
_source_node_id,
|
||||
_dest_node_id,
|
||||
_mission_criticality,
|
||||
_running=False,
|
||||
):
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_load: int,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_source_node_id: str,
|
||||
_dest_node_id: str,
|
||||
_mission_criticality: int,
|
||||
_running: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an Information Exchange Request.
|
||||
|
||||
@@ -36,18 +36,18 @@ class IER(object):
|
||||
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
:param _running: Indicates whether the IER is currently running
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.source_node_id = _source_node_id
|
||||
self.dest_node_id = _dest_node_id
|
||||
self.load = _load
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.mission_criticality = _mission_criticality
|
||||
self.running = _running
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.source_node_id: str = _source_node_id
|
||||
self.dest_node_id: str = _dest_node_id
|
||||
self.load: int = _load
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
self.mission_criticality: int = _mission_criticality
|
||||
self.running: bool = _running
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets IER ID.
|
||||
|
||||
@@ -56,7 +56,7 @@ class IER(object):
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets IER start step.
|
||||
|
||||
@@ -65,7 +65,7 @@ class IER(object):
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets IER end step.
|
||||
|
||||
@@ -74,7 +74,7 @@ class IER(object):
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets IER load.
|
||||
|
||||
@@ -83,7 +83,7 @@ class IER(object):
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets IER protocol.
|
||||
|
||||
@@ -92,7 +92,7 @@ class IER(object):
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets IER port.
|
||||
|
||||
@@ -101,7 +101,7 @@ class IER(object):
|
||||
"""
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
|
||||
@@ -110,7 +110,7 @@ class IER(object):
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
def get_dest_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
|
||||
@@ -119,7 +119,7 @@ class IER(object):
|
||||
"""
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
def get_is_running(self) -> bool:
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
@@ -128,7 +128,7 @@ class IER(object):
|
||||
"""
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
def set_is_running(self, _value: bool) -> None:
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
|
||||
@@ -137,7 +137,7 @@ class IER(object):
|
||||
"""
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self):
|
||||
def get_mission_criticality(self) -> int:
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE = False
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(
|
||||
@@ -23,7 +23,7 @@ def apply_red_agent_iers(
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
@@ -213,7 +213,7 @@ def apply_red_agent_node_pol(
|
||||
iers: Dict[str, IER],
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Any, Dict, Final, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
@@ -29,7 +29,7 @@ class PrimaiteSession:
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
@@ -52,7 +52,7 @@ class PrimaiteSession:
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self):
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
@@ -123,8 +123,8 @@ class PrimaiteSession:
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -135,8 +135,8 @@ class PrimaiteSession:
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -145,6 +145,6 @@ class PrimaiteSession:
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -3,15 +3,19 @@ import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True):
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
|
||||
@@ -2,15 +2,19 @@ import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing=True):
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""
|
||||
Handles creation of application directories and user directories.
|
||||
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
@@ -17,7 +21,7 @@ class Transaction(object):
|
||||
:param episode_number: The episode number
|
||||
:param step_number: The step number
|
||||
"""
|
||||
self.timestamp = datetime.now()
|
||||
self.timestamp: datetime = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier: AgentIdentifier = agent_identifier
|
||||
"The agent identifier"
|
||||
@@ -25,17 +29,17 @@ class Transaction(object):
|
||||
"The episode number"
|
||||
self.step_number: int = step_number
|
||||
"The step number"
|
||||
self.obs_space = None
|
||||
self.obs_space: "spaces.Space" = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre = None
|
||||
self.obs_space_pre: Union["np.ndarray", Tuple["np.ndarray"]] = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post = None
|
||||
self.obs_space_post: Union["np.ndarray", Tuple["np.ndarray"]] = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward: float = None
|
||||
"The reward value"
|
||||
self.action_space = None
|
||||
self.action_space: int = None
|
||||
"The action space invoked by the agent"
|
||||
self.obs_space_description = None
|
||||
self.obs_space_description: List[str] = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
@@ -68,7 +72,7 @@ class Transaction(object):
|
||||
return header, row
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
@@ -81,7 +85,7 @@ def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
return [str(action_space)]
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
|
||||
@@ -6,6 +6,9 @@ from primaite import getLogger
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
@@ -28,7 +31,7 @@ class SessionOutputWriter:
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Session Output Writer.
|
||||
|
||||
@@ -41,15 +44,16 @@ class SessionOutputWriter:
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
:type learning_session: bool, optional
|
||||
"""
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
self.learning_session = learning_session
|
||||
self._env: "Primaite" = env
|
||||
self.transaction_writer: bool = transaction_writer
|
||||
self.learning_session: bool = learning_session
|
||||
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
self._csv_file_path = self._env.session_path / "learning" / fn
|
||||
else:
|
||||
@@ -57,26 +61,26 @@ class SessionOutputWriter:
|
||||
|
||||
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
self._csv_file = None
|
||||
self._csv_writer = None
|
||||
self._csv_file: "TextIOWrapper" = None
|
||||
self._csv_writer: "csv._writer" = None
|
||||
|
||||
self._first_write: bool = True
|
||||
|
||||
def _init_csv_writer(self):
|
||||
def _init_csv_writer(self) -> None:
|
||||
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
|
||||
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close the cvs file."""
|
||||
if self._csv_file:
|
||||
self._csv_file.close()
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
def write(self, data: Union[Tuple, Transaction]) -> None:
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user