Added type hints
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
|
||||
@@ -6,6 +6,9 @@ from primaite import getLogger
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
@@ -28,7 +31,7 @@ class SessionOutputWriter:
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Session Output Writer.
|
||||
|
||||
@@ -41,15 +44,16 @@ class SessionOutputWriter:
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
:type learning_session: bool, optional
|
||||
"""
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
self.learning_session = learning_session
|
||||
self._env: "Primaite" = env
|
||||
self.transaction_writer: bool = transaction_writer
|
||||
self.learning_session: bool = learning_session
|
||||
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
self._csv_file_path = self._env.session_path / "learning" / fn
|
||||
else:
|
||||
@@ -57,26 +61,26 @@ class SessionOutputWriter:
|
||||
|
||||
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
self._csv_file = None
|
||||
self._csv_writer = None
|
||||
self._csv_file: "TextIOWrapper" = None
|
||||
self._csv_writer: "csv._writer" = None
|
||||
|
||||
self._first_write: bool = True
|
||||
|
||||
def _init_csv_writer(self):
|
||||
def _init_csv_writer(self) -> None:
|
||||
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
|
||||
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close the cvs file."""
|
||||
if self._csv_file:
|
||||
self._csv_file.close()
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
def write(self, data: Union[Tuple, Transaction]) -> None:
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user