#1629 - Added assertion in the test that checks the length of the all transactions file too.

- Added supporting function on the TempPrimaiteSession class that reads the all transactions csv file.
- Some renaming of the functions.
This commit is contained in:
Chris McCarthy
2023-07-17 12:14:47 +01:00
parent 75c91b9eb9
commit 360eb38c2b
5 changed files with 48 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
@@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.
:return: The average rewards per episode csv as a dict.
"""
df = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""
Read an all transactions csv file and return as a dict.
The dict keys are a tuple with the structure (episode, step). The dict
values are the remaining columns as a dict.
:param all_transactions_csv_file: The all transactions csv file path.
:return: The all transactions csv file as a dict.
"""
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
new_dict = {}
episodes = df_dict["Episode"]
steps = df_dict["Step"]
keys = list(df_dict.keys())
for i in range(len(episodes)):
key = (episodes[i], steps[i])
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
new_dict[key] = value_dict
return new_dict

View File

@@ -5,7 +5,7 @@ import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Tuple, Union
from unittest.mock import patch
import pytest
@@ -13,7 +13,7 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
def learn_av_reward_per_episode(self) -> Dict[int, float]:
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:

View File

@@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
"""
with temp_primaite_session as session:
session.evaluate()
ev_rewards = session.eval_av_reward_per_episode_csv()
ev_rewards = session.eval_av_reward_per_episode_dict()
assert ev_rewards[1] == -8.0

View File

@@ -19,4 +19,5 @@ def test_primaite_session(temp_primaite_session):
assert session_path.exists()
session.learn()
assert len(session.learn_av_reward_per_episode().keys()) == 10
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256

View File

@@ -28,7 +28,7 @@ def test_seeded_learning(temp_primaite_session):
"Expected output is based upon a agent that was trained with " "seed 67890"
)
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@@ -45,5 +45,5 @@ def test_deterministic_evaluation(temp_primaite_session):
# do stuff
session.learn()
session.evaluate()
eval_mean_reward = session.eval_av_reward_per_episode_csv()
eval_mean_reward = session.eval_av_reward_per_episode_dict()
assert len(set(eval_mean_reward.values())) == 1