#1629 - Added assertion in the test that checks the length of the all transactions file too.
- Added supporting function on the TempPrimaiteSession class that reads the all transactions csv file. - Some renaming of the functions.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
# Using polars as it's faster than Pandas; it will speed things up when
|
||||
# files get big!
|
||||
@@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode cdv as a dict.
|
||||
:return: The average rewards per episode csv as a dict.
|
||||
"""
|
||||
df = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
|
||||
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
|
||||
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
|
||||
|
||||
|
||||
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""
|
||||
Read an all transactions csv file and return as a dict.
|
||||
|
||||
The dict keys are a tuple with the structure (episode, step). The dict
|
||||
values are the remaining columns as a dict.
|
||||
|
||||
:param all_transactions_csv_file: The all transactions csv file path.
|
||||
:return: The all transactions csv file as a dict.
|
||||
"""
|
||||
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
|
||||
new_dict = {}
|
||||
|
||||
episodes = df_dict["Episode"]
|
||||
steps = df_dict["Step"]
|
||||
keys = list(df_dict.keys())
|
||||
|
||||
for i in range(len(episodes)):
|
||||
key = (episodes[i], steps[i])
|
||||
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
|
||||
new_dict[key] = value_dict
|
||||
|
||||
return new_dict
|
||||
|
||||
@@ -5,7 +5,7 @@ import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -13,7 +13,7 @@ import pytest
|
||||
from primaite import getLogger
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
|
||||
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
|
||||
def learn_av_reward_per_episode(self) -> Dict[int, float]:
|
||||
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the learn av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
|
||||
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the eval av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the learn all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the eval all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def metadata_file_as_dict(self) -> Dict[str, Any]:
|
||||
"""Read the session_metadata.json file and return as a dict."""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
|
||||
@@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
session.evaluate()
|
||||
ev_rewards = session.eval_av_reward_per_episode_csv()
|
||||
ev_rewards = session.eval_av_reward_per_episode_dict()
|
||||
assert ev_rewards[1] == -8.0
|
||||
|
||||
@@ -19,4 +19,5 @@ def test_primaite_session(temp_primaite_session):
|
||||
assert session_path.exists()
|
||||
session.learn()
|
||||
|
||||
assert len(session.learn_av_reward_per_episode().keys()) == 10
|
||||
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
|
||||
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_seeded_learning(temp_primaite_session):
|
||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||
)
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
@@ -45,5 +45,5 @@ def test_deterministic_evaluation(temp_primaite_session):
|
||||
# do stuff
|
||||
session.learn()
|
||||
session.evaluate()
|
||||
eval_mean_reward = session.eval_av_reward_per_episode_csv()
|
||||
eval_mean_reward = session.eval_av_reward_per_episode_dict()
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
Reference in New Issue
Block a user