diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 24815727..64857c80 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -21,6 +21,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not config_path.exists(): + config_path = Path( pkg_resources.resource_filename( "primaite", "setup/_package_data/primaite_config.yaml" @@ -36,7 +37,7 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[primaite_config["log_level"]] + primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] return primaite_config @@ -108,11 +109,11 @@ def _log_dir() -> Path: _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { - logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"], - logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"], - logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"], - logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"] + logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], + logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], + logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], + logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"] } ) @@ -132,10 +133,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( backupCount=9, # Max 100MB of logs encoding="utf8", ) -_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) -_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) +_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) +_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) -_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"] +_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"] _STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) _FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) @@ -145,7 +146,7 @@ _LOGGER.addHandler(_STREAM_HANDLER) _LOGGER.addHandler(_FILE_HANDLER) -def getLogger(name: str) -> Logger: +def getLogger(name: str) -> Logger: # noqa """ Get a PrimAITE logger. diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 5f4dac8f..05133b7e 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -64,7 +64,11 @@ class AgentSessionABC(ABC): "The session timestamp" self.session_path = _get_session_path(self.session_timestamp) "The Session path" - self.checkpoints_path = self.session_path / "checkpoints" + self.learning_path = self.session_path / "learning" + "The learning outputs path" + self.evaluation_path = self.session_path / "evaluation" + "The evaluation outputs path" + self.checkpoints_path = self.learning_path / "checkpoints" self.checkpoints_path.mkdir(parents=True, exist_ok=True) "The Session checkpoints path" @@ -205,7 +209,6 @@ class HardCodedAgentSessionABC(AgentSessionABC): self._env: Primaite = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, - transaction_list=[], session_path=self.session_path, timestamp_str=self.timestamp_str ) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 710225d7..8a6428bb 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -21,7 +21,6 @@ def _env_creator(env_config): return Primaite( training_config_path=env_config["training_config_path"], lay_down_config_path=env_config["lay_down_config_path"], - transaction_list=env_config["transaction_list"], session_path=env_config["session_path"], timestamp_str=env_config["timestamp_str"] ) @@ -106,7 +105,6 @@ class RLlibAgent(AgentSessionABC): env_config=dict( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, - transaction_list=[], session_path=self.session_path, timestamp_str=self.timestamp_str ) diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 4d2ded6b..c183c544 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -34,7 +34,7 @@ class SB3Agent(AgentSessionABC): _LOGGER.error(msg) raise ValueError(msg) - self._tensorboard_log_path = self.session_path / "tensorboard_logs" + self._tensorboard_log_path = self.learning_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) self._setup() _LOGGER.debug( @@ -49,7 +49,6 @@ class SB3Agent(AgentSessionABC): self._env = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, - transaction_list=[], session_path=self.session_path, timestamp_str=self.timestamp_str ) @@ -108,10 +107,13 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - - _LOGGER.info(f"Beginning evaluation for {episodes} episodes @" - f" {time_steps} time steps...") - + self._env.set_as_eval() + if deterministic: + deterministic_str = "deterministic" + else: + deterministic_str = "non-deterministic" + _LOGGER.info(f"Beginning {deterministic_str} evaluation for " + f"{episodes} episodes @ {time_steps} time steps...") for episode in range(episodes): obs = self._env.reset() @@ -123,6 +125,7 @@ class SB3Agent(AgentSessionABC): if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + _LOGGER.info(f"Finished evaluation") @classmethod def load(self): diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 9cbcb702..0e0212f4 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -38,7 +38,7 @@ hard_coded_agent_view: FULL action_type: ANY # Number of episodes to run per session -num_episodes: 10 +num_episodes: 1000 # Number of time_steps per episode num_steps: 256 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5319d0f1..5b344a99 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,10 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy -import csv -import logging -import time -from datetime import datetime from pathlib import Path from typing import Dict, Tuple, Union, Final @@ -14,6 +10,7 @@ import yaml from gym import Env, spaces from matplotlib import pyplot as plt +from primaite import getLogger from primaite.acl.access_control_list import AccessControlList from primaite.agents.utils import is_valid_acl_action_extra, \ is_valid_node_action @@ -27,7 +24,7 @@ from primaite.common.enums import ( NodeType, ObservationType, Priority, - SoftwareState, + SoftwareState, SessionType, ) from primaite.common.service import Service from primaite.config import training_config @@ -47,11 +44,9 @@ from primaite.pol.ier import IER from primaite.pol.red_agent_pol import apply_red_agent_iers, \ apply_red_agent_node_pol from primaite.transactions.transaction import Transaction -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file +from primaite.utils.session_output_writer import SessionOutputWriter -_LOGGER = logging.getLogger(__name__) -_LOGGER.setLevel(logging.INFO) +_LOGGER = getLogger(__name__) class Primaite(Env): @@ -67,7 +62,6 @@ class Primaite(Env): self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], - transaction_list, session_path: Path, timestamp_str: str, ): @@ -76,7 +70,6 @@ class Primaite(Env): :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. - :param transaction_list: The list of transactions to populate. :param session_path: The directory path the session is writing to. :param timestamp_str: The session timestamp in the format: _. @@ -96,9 +89,6 @@ class Primaite(Env): super(Primaite, self).__init__() - # Transaction list - self.transaction_list = transaction_list - # The agent in use self.agent_identifier = self.training_config.agent_identifier @@ -245,20 +235,31 @@ class Primaite(Env): _LOGGER.error( f"Invalid action type selected: {self.training_config.action_type}" ) - # Set up a csv to store the results of the training - try: - header = ["Episode", "Average Reward"] - file_name = f"average_reward_per_episode_{timestamp_str}.csv" - file_path = session_path / file_name - self.csv_file = open(file_path, "w", encoding="UTF8", newline="") - self.csv_writer = csv.writer(self.csv_file) - self.csv_writer.writerow(header) - except Exception: - _LOGGER.error( - "Could not create csv file to hold average reward per episode" - ) - _LOGGER.error("Exception occured", exc_info=True) + self.episode_av_reward_writer = SessionOutputWriter( + self, + transaction_writer=False, + learning_session=True + ) + self.transaction_writer = SessionOutputWriter( + self, + transaction_writer=True, + learning_session=True + ) + + def set_as_eval(self): + """Set the writers to write to eval directories.""" + self.episode_av_reward_writer = SessionOutputWriter( + self, + transaction_writer=False, + learning_session=False + ) + self.transaction_writer = SessionOutputWriter( + self, + transaction_writer=True, + learning_session=False + ) + self.episode_count = 0 def reset(self): """ @@ -267,12 +268,14 @@ class Primaite(Env): Returns: Environment observation space (reset) """ - csv_data = self.episode_count, self.average_reward - self.csv_writer.writerow(csv_data) + if self.episode_count > 0: + csv_data = self.episode_count, self.average_reward + self.episode_av_reward_writer.write(csv_data) self.episode_count += 1 - # Don't need to reset links, as they are cleared and recalculated every step + # Don't need to reset links, as they are cleared and recalculated every + # step # Clear the ACL self.init_acl() @@ -303,12 +306,8 @@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ - if self.step_count == 0: - _LOGGER.info(f"Episode: {str(self.episode_count)}") - # TEMP done = False - self.step_count += 1 self.total_step_count += 1 @@ -321,13 +320,16 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction( - datetime.now(), self.agent_identifier, self.episode_count, + self.agent_identifier, + self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) + transaction.obs_space_pre = copy.deepcopy(self.env_obs) # Load the action space into the transaction - transaction.set_action_space(copy.deepcopy(action)) + transaction.action_space = copy.deepcopy(action) + + initial_nodes = copy.deepcopy(self.nodes) # 1. Implement Blue Action self.interpret_action_and_apply(action) @@ -381,7 +383,7 @@ class Primaite(Env): # 5. Calculate reward signal (for RL) reward = calculate_reward_function( - self.nodes_post_pol, + initial_nodes, self.nodes_post_red, self.nodes_reference, self.green_iers, @@ -390,17 +392,22 @@ class Primaite(Env): self.step_count, self.training_config, ) - _LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}") + _LOGGER.debug( + f"Episode: {self.episode_count}, " + f"Step {self.step_count}, " + f"Reward: {reward}" + ) self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count - if self.training_config.session_type == "EVALUATION": + if self.training_config.session_type is SessionType.EVAL: # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - _LOGGER.info(f" Average Reward: {str(self.average_reward)}") + _LOGGER.info(f"Episode: {self.episode_count}, " + f"Average Reward: {self.average_reward}") # Load the reward into the transaction - transaction.set_reward(reward) + transaction.reward = reward # 6. Output Verbose # self.output_link_status() @@ -408,28 +415,14 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() # Load the new observation space into the transaction - transaction.set_obs_space_post(copy.deepcopy(self.env_obs)) + transaction.obs_space_post = copy.deepcopy(self.env_obs) - # 8. Add the transaction to the list of transactions - self.transaction_list.append(copy.deepcopy(transaction)) + # Write transaction to file + self.transaction_writer.write(transaction) # Return return self.env_obs, reward, done, self.step_info - def close(self): - self.__close__() - - def __close__(self): - """ - Override close function - """ - write_transaction_to_file( - self.transaction_list, - self.session_path, - self.timestamp_str - ) - self.csv_file.close() - def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() @@ -467,7 +460,7 @@ class Primaite(Env): ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: - logging.error("Invalid action type found") + _LOGGER.error("Invalid action type found") def apply_actions_to_nodes(self, _action): """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 1a1a0770..00e45fa3 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -85,9 +85,6 @@ def calculate_reward_function( ) if live_blocked and not reference_blocked: - _LOGGER.debug( - f"Applying reward of {ier_reward} because IER {ier_key} is blocked" - ) reward_value += ier_reward elif live_blocked and reference_blocked: _LOGGER.debug( diff --git a/src/primaite/main.py b/src/primaite/main.py index 5aba68ef..3c0f93b3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -20,6 +20,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, session.setup() session.learn() + session.evaluate() if __name__ == "__main__": diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 5d469ffe..1dd8775b 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -1,10 +1,11 @@ # The main PrimAITE application config file # Logging -log_level: INFO -logger_format: - DEBUG: '%(asctime)s: %(message)s' - INFO: '%(asctime)s: %(message)s' - WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' - ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' - CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' +logging: + log_level: INFO + logger_format: + DEBUG: '%(asctime)s: %(message)s' + INFO: '%(asctime)s: %(message)s' + WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a4ce48e3..6e5ba5f0 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -1,57 +1,115 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """The Transaction class.""" +from datetime import datetime +from typing import List, Tuple class Transaction(object): """Transaction class.""" - def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number): + def __init__( + self, + agent_identifier, + episode_number, + step_number + ): """ - Init. + Transaction constructor. - Args: - _timestamp: The time this object was created - _agent_identifier: An identifier for the agent in use - _episode_number: The episode number - _step_number: The step number + :param agent_identifier: An identifier for the agent in use + :param episode_number: The episode number + :param step_number: The step number """ - self.timestamp = _timestamp - self.agent_identifier = _agent_identifier - self.episode_number = _episode_number - self.step_number = _step_number + self.timestamp = datetime.now() + "The datetime of the transaction" + self.agent_identifier = agent_identifier + self.episode_number = episode_number + "The episode number" + self.step_number = step_number + "The step number" + self.obs_space_pre = None + "The observation space before any actions are taken" + self.obs_space_post = None + "The observation space after any actions are taken" + self.reward = None + "The reward value" + self.action_space = None + "The action space invoked by the agent" - def set_obs_space_pre(self, _obs_space_pre): - """ - Sets the observation space (pre). + def as_csv_data(self) -> Tuple[List, List]: + if isinstance(self.action_space, int): + action_length = self.action_space + else: + action_length = self.action_space.size + obs_shape = self.obs_space_post.shape + obs_assets = self.obs_space_post.shape[0] + if len(obs_shape) == 1: + # A bit of a workaround but I think the way transactions are + # written will change soon + obs_features = 1 + else: + obs_features = self.obs_space_post.shape[1] - Args: - _obs_space_pre: The observation space before any actions are taken - """ - self.obs_space_pre = _obs_space_pre + # Create the action space headers array + action_header = [] + for x in range(action_length): + action_header.append("AS_" + str(x)) - def set_obs_space_post(self, _obs_space_post): - """ - Sets the observation space (post). + # Create the observation space headers array + obs_header_initial = [] + obs_header_new = [] + for x in range(obs_assets): + for y in range(obs_features): + obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) + obs_header_new.append("OSN_" + str(x) + "_" + str(y)) - Args: - _obs_space_post: The observation space after any actions are taken - """ - self.obs_space_post = _obs_space_post + # Open up a csv file + header = ["Timestamp", "Episode", "Step", "Reward"] + header = header + action_header + obs_header_initial + obs_header_new - def set_reward(self, _reward): - """ - Sets the reward. + row = [ + str(self.timestamp), + str(self.episode_number), + str(self.step_number), + str(self.reward), + ] + row = ( + row + + _turn_action_space_to_array(self.action_space) + + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, + obs_features) + + _turn_obs_space_to_array(self.obs_space_post, obs_assets, + obs_features) + ) + return header, row - Args: - _reward: The reward value - """ - self.reward = _reward - def set_action_space(self, _action_space): - """ - Sets the action space. +def _turn_action_space_to_array(action_space) -> List[str]: + """ + Turns action space into a string array so it can be saved to csv. - Args: - _action_space: The action space invoked by the agent - """ - self.action_space = _action_space + :param action_space: The action space + :return: The action space as an array of strings + """ + if isinstance(action_space, list): + return [str(i) for i in action_space] + else: + return [str(action_space)] + + +def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]: + """ + Turns observation space into a string array so it can be saved to csv. + + :param obs_space: The observation space + :param obs_assets: The number of assets (i.e. nodes or links) in the + observation space + :param obs_features: The number of features associated with the asset + :return: The observation space as an array of strings + """ + return_array = [] + for x in range(obs_assets): + for y in range(obs_features): + return_array.append(str(obs_space[x][y])) + + return return_array diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py deleted file mode 100644 index ed7a8f1c..00000000 --- a/src/primaite/transactions/transactions_to_file.py +++ /dev/null @@ -1,119 +0,0 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -"""Writes the Transaction log list out to file for evaluation to utilse.""" - -import csv -from pathlib import Path - -import numpy as np - -from primaite import getLogger - -_LOGGER = getLogger(__name__) - - -def turn_action_space_to_array(_action_space): - """ - Turns action space into a string array so it can be saved to csv. - - Args: - _action_space: The action space. - """ - if isinstance(_action_space, list): - return [str(i) for i in _action_space] - else: - return [str(_action_space)] - - -def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): - """ - Turns observation space into a string array so it can be saved to csv. - - Args: - _obs_space: The observation space - _obs_assets: The number of assets (i.e. nodes or links) in the observation space - _obs_features: The number of features associated with the asset - """ - return_array = [] - for x in range(_obs_assets): - for y in range(_obs_features): - return_array.append(str(_obs_space[x][y])) - - return return_array - - -def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): - """ - Writes transaction logs to file to support training evaluation. - - :param transaction_list: The list of transactions from all steps and all - episodes. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - # Get the first transaction and use it to determine the makeup of the - # observation space and action space - # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action - # space as "AS_1" - # This will be tied into the PrimAITE Use Case so that they make sense - - template_transation = transaction_list[0] - if isinstance(template_transation.action_space, int): - action_length = template_transation.action_space - else: - action_length = template_transation.action_space.size - obs_shape = template_transation.obs_space_post.shape - obs_assets = template_transation.obs_space_post.shape[0] - if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - obs_features = 1 - else: - obs_features = template_transation.obs_space_post.shape[1] - - # Create the action space headers array - action_header = [] - for x in range(action_length): - action_header.append("AS_" + str(x)) - - # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) - - # Open up a csv file - header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new - - try: - filename = session_path / f"all_transactions_{timestamp_str}.csv" - _LOGGER.debug(f"Saving transaction logs: {filename}") - csv_file = open(filename, "w", encoding="UTF8", newline="") - csv_writer = csv.writer(csv_file) - csv_writer.writerow(header) - - for transaction in transaction_list: - csv_data = [ - str(transaction.timestamp), - str(transaction.episode_number), - str(transaction.step_number), - str(transaction.reward), - ] - csv_data = ( - csv_data - + turn_action_space_to_array(transaction.action_space) - + turn_obs_space_to_array( - transaction.obs_space_pre, obs_assets, obs_features - ) - + turn_obs_space_to_array( - transaction.obs_space_post, obs_assets, obs_features - ) - ) - csv_writer.writerow(csv_data) - - csv_file.close() - _LOGGER.debug("Finished writing transactions") - except Exception: - _LOGGER.error("Could not save the transaction file", exc_info=True) diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py new file mode 100644 index 00000000..308e1fb3 --- /dev/null +++ b/src/primaite/utils/session_output_writer.py @@ -0,0 +1,73 @@ +import csv +from logging import Logger +from typing import List, Final, IO, Union, Tuple +from typing import TYPE_CHECKING + +from primaite import getLogger +from primaite.transactions.transaction import Transaction + +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + +_LOGGER: Logger = getLogger(__name__) + + +class SessionOutputWriter: + _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [ + "Episode", "Average Reward" + ] + + def __init__( + self, + env: "Primaite", + transaction_writer: bool = False, + learning_session: bool = True + ): + self._env = env + self.transaction_writer = transaction_writer + self.learning_session = learning_session + + if self.transaction_writer: + fn = f"all_transactions_{self._env.timestamp_str}.csv" + else: + fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv" + + if self.learning_session: + self._csv_file_path = self._env.session_path / "learning" / fn + else: + self._csv_file_path = self._env.session_path / "evaluation" / fn + + self._csv_file_path.parent.mkdir(exist_ok=True, parents=True) + + self._csv_file = None + self._csv_writer = None + + self._first_write: bool = True + + def _init_csv_writer(self): + self._csv_file = open( + self._csv_file_path, "w", encoding="UTF8", newline="" + ) + + self._csv_writer = csv.writer(self._csv_file) + + def __del__(self): + if self._csv_file: + self._csv_file.close() + _LOGGER.info(f"Finished writing file: {self._csv_file_path}") + + def write( + self, + data: Union[Tuple, Transaction] + ): + if isinstance(data, Transaction): + header, data = data.as_csv_data() + else: + header = self._AV_REWARD_PER_EPISODE_HEADER + + if self._first_write: + self._init_csv_writer() + self._csv_writer.writerow(header) + self._first_write = False + + self._csv_writer.writerow(data) diff --git a/tests/conftest.py b/tests/conftest.py index 1bad5db0..945d23f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,7 +37,6 @@ def _get_primaite_env_from_config( env = Primaite( training_config_path=training_config_path, lay_down_config_path=lay_down_config_path, - transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, )