Added type hints

This commit is contained in:
Marek Wolan
2023-07-14 12:01:38 +01:00
parent a923d818d3
commit c57ed6edcd
16 changed files with 166 additions and 128 deletions

View File

@@ -1,12 +1,16 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import os
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
def get_file_path(path: str) -> Path:

View File

@@ -6,6 +6,9 @@ from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from io import TextIOWrapper
from pathlib import Path
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
@@ -28,7 +31,7 @@ class SessionOutputWriter:
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True,
):
) -> None:
"""
Initialise the Session Output Writer.
@@ -41,15 +44,16 @@ class SessionOutputWriter:
determines the name of the folder which contains the final output csv. Defaults to True
:type learning_session: bool, optional
"""
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
self._env: "Primaite" = env
self.transaction_writer: bool = transaction_writer
self.learning_session: bool = learning_session
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
self._csv_file_path: "Path"
if self.learning_session:
self._csv_file_path = self._env.session_path / "learning" / fn
else:
@@ -57,26 +61,26 @@ class SessionOutputWriter:
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
self._csv_file = None
self._csv_writer = None
self._csv_file: "TextIOWrapper" = None
self._csv_writer: "csv._writer" = None
self._first_write: bool = True
def _init_csv_writer(self):
def _init_csv_writer(self) -> None:
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
def __del__(self) -> None:
self.close()
def close(self):
def close(self) -> None:
"""Close the cvs file."""
if self._csv_file:
self._csv_file.close()
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(self, data: Union[Tuple, Transaction]):
def write(self, data: Union[Tuple, Transaction]) -> None:
"""
Write a row of session data.