Files
PrimAITE/src/primaite/utils/session_output_writer.py
Chris McCarthy e11fd2ced4 #917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
2023-06-30 16:52:57 +01:00

84 lines
2.3 KiB
Python

import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
# Module-level logger, obtained from the primaite getLogger helper and named after this module.
_LOGGER: Logger = getLogger(__name__)
class SessionOutputWriter:
    """
    A session output writer class.

    Is used to write session outputs to csv file. The output directory is
    created when the writer is constructed, but the csv file itself is only
    opened on the first call to :meth:`write`.
    """

    # Header row used when the data written is a plain tuple rather than a
    # Transaction (a Transaction supplies its own header).
    _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
        "Episode",
        "Average Reward",
    ]

    def __init__(
        self,
        env: "Primaite",
        transaction_writer: bool = False,
        learning_session: bool = True,
    ):
        """
        Initialise the writer and work out where its csv file will live.

        :param env: The environment this writer belongs to. Its
            ``timestamp_str`` is used in the file name and its
            ``session_path`` is the root of the output directory.
        :param transaction_writer: If True the file is named
            ``all_transactions_<timestamp>.csv``, otherwise
            ``average_reward_per_episode_<timestamp>.csv``.
        :param learning_session: If True the file is placed in the
            ``learning`` sub-directory of the session path, otherwise in
            ``evaluation``.
        """
        # Set the file-handle state first so that close() / __del__ always
        # find these attributes, even if the path set-up below raises.
        self._csv_file = None
        self._csv_writer = None
        self._first_write: bool = True

        self._env = env
        self.transaction_writer = transaction_writer
        self.learning_session = learning_session

        if self.transaction_writer:
            fn = f"all_transactions_{self._env.timestamp_str}.csv"
        else:
            fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"

        if self.learning_session:
            self._csv_file_path = self._env.session_path / "learning" / fn
        else:
            self._csv_file_path = self._env.session_path / "evaluation" / fn

        self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)

    def _init_csv_writer(self) -> None:
        """Open the csv file for writing and attach a csv writer to it."""
        # newline="" is what the csv module expects of files it writes to.
        self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
        self._csv_writer = csv.writer(self._csv_file)

    def __del__(self):
        """Close the csv file when the writer is garbage collected."""
        self.close()

    def close(self) -> None:
        """
        Close the csv file.

        Does nothing if the file was never opened or has already been
        closed, so it is safe to call more than once (for example explicitly
        and then again from ``__del__``).
        """
        # getattr guards against a partially constructed instance.
        csv_file = getattr(self, "_csv_file", None)
        if csv_file is None:
            return
        # Drop the reference before closing so a repeated call is a no-op
        # and the "finished" message is only logged once.
        self._csv_file = None
        csv_file.close()
        _LOGGER.debug(f"Finished writing file: {self._csv_file_path}")

    def write(self, data: Union[Tuple, Transaction]) -> None:
        """
        Write a row of session data.

        On the first call the csv file is opened and a header row is written
        before the data row.

        :param data: The row of data to write. Can be a Tuple or an instance
            of Transaction. A Transaction provides both its header and its
            row via ``as_csv_data()``; a Tuple is written as-is under the
            average-reward-per-episode header.
        """
        if isinstance(data, Transaction):
            header, data = data.as_csv_data()
        else:
            header = self._AV_REWARD_PER_EPISODE_HEADER

        if self._first_write:
            self._init_csv_writer()
            self._csv_writer.writerow(header)
            self._first_write = False

        self._csv_writer.writerow(data)