diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py
index ad3dd4f4..2ff4a16a 100644
--- a/src/primaite/utils/session_output_reader.py
+++ b/src/primaite/utils/session_output_reader.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Dict, Union
+from typing import Any, Dict, Tuple, Union
 
 # Using polars as it's faster than Pandas; it will speed things up when
 # files get big!
@@ -13,8 +13,34 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
     The dictionary keys are the episode number, and the values are the mean reward that episode.
 
     :param av_rewards_csv_file: The average rewards per episode csv file path.
-    :return: The average rewards per episode cdv as a dict.
+    :return: The average rewards per episode csv as a dict.
     """
-    df = pl.read_csv(av_rewards_csv_file).to_dict()
+    df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
 
-    return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
+    return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
+
+
+def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
+    """
+    Read an all transactions csv file and return as a dict.
+
+    The dict keys are a tuple with the structure (episode, step). The dict
+    values are the remaining columns as a dict, keyed by column name. If an
+    (episode, step) pair appears on more than one row, the last row wins.
+
+    :param all_transactions_csv_file: The all transactions csv file path.
+    :return: The all transactions csv file as a dict.
+    """
+    df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
+
+    episodes = df_dict["Episode"]
+    steps = df_dict["Step"]
+    # The value columns are the same for every row, so work them out once
+    # rather than re-filtering the column names inside the loop.
+    value_columns = [column for column in df_dict if column not in ("Episode", "Step")]
+
+    new_dict: Dict[Tuple[int, int], Dict[str, Any]] = {}
+    for i, (episode, step) in enumerate(zip(episodes, steps)):
+        new_dict[(episode, step)] = {column: df_dict[column][i] for column in value_columns}
+
+    return new_dict
diff --git a/tests/conftest.py b/tests/conftest.py
index aaf4dbce..5cfd2274 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,7 @@ import shutil
 import tempfile
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Union
+from typing import Any, Dict, Tuple, Union
 from unittest.mock import patch
 
 import pytest
@@ -13,7 +13,7 @@ import pytest
 from primaite import getLogger
 from primaite.environment.primaite_env import Primaite
 from primaite.primaite_session import PrimaiteSession
-from primaite.utils.session_output_reader import av_rewards_dict
+from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
 from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
 
 ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
         super().__init__(training_config_path, lay_down_config_path)
         self.setup()
 
-    def learn_av_reward_per_episode(self) -> Dict[int, float]:
+    def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
         """Get the learn av reward per episode from file."""
         csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
         return av_rewards_dict(self.learning_path / csv_file)
 
-    def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
+    def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
         """Get the eval av reward per episode from file."""
         csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
         return av_rewards_dict(self.evaluation_path / csv_file)
 
+    def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
+        """Get the learn all transactions from file."""
+        csv_file = f"all_transactions_{self.timestamp_str}.csv"
+        return all_transactions_dict(self.learning_path / csv_file)
+
+    def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
+        """Get the eval all transactions from file."""
+        csv_file = f"all_transactions_{self.timestamp_str}.csv"
+        return all_transactions_dict(self.evaluation_path / csv_file)
+
     def metadata_file_as_dict(self) -> Dict[str, Any]:
         """Read the session_metadata.json file and return as a dict."""
         with open(self.session_path / "session_metadata.json", "r") as file:
diff --git a/tests/test_reward.py b/tests/test_reward.py
index bb6eb1b0..2edfd44a 100644
--- a/tests/test_reward.py
+++ b/tests/test_reward.py
@@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
     """
     with temp_primaite_session as session:
         session.evaluate()
-        ev_rewards = session.eval_av_reward_per_episode_csv()
+        ev_rewards = session.eval_av_reward_per_episode_dict()
         assert ev_rewards[1] == -8.0
diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py
index cd98734c..645214e3 100644
--- a/tests/test_rllib_agent.py
+++ b/tests/test_rllib_agent.py
@@ -19,4 +19,5 @@ def test_primaite_session(temp_primaite_session):
         assert session_path.exists()
         session.learn()
 
-        assert len(session.learn_av_reward_per_episode().keys()) == 10
+        assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
+        assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py
index 34cb43fb..f52e9eee 100644
--- a/tests/test_seeding_and_deterministic_session.py
+++ b/tests/test_seeding_and_deterministic_session.py
@@ -28,7 +28,7 @@ def test_seeded_learning(temp_primaite_session):
             "Expected output is based upon a agent that was trained with " "seed 67890"
         )
         session.learn()
-        actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
+        actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
 
         assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
 
@@ -45,5 +45,5 @@ def test_deterministic_evaluation(temp_primaite_session):
         # do stuff
         session.learn()
         session.evaluate()
-        eval_mean_reward = session.eval_av_reward_per_episode_csv()
+        eval_mean_reward = session.eval_av_reward_per_episode_dict()
         assert len(set(eval_mean_reward.values())) == 1