Files
PrimAITE/src/primaite/utils/session_output_reader.py
SunilSamra 6b59ce960d Merge remote-tracking branch 'origin/dev' into feature/1566-configure_episode-steps-learn-eval
# Conflicts:
#	src/primaite/config/training_config.py
2023-07-11 11:39:21 +01:00

21 lines
703 B
Python

from pathlib import Path
from typing import Dict, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
import polars as pl
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
    """
    Read an average rewards per episode csv file and return as a dict.

    The dictionary keys are the episode number, and the values are the mean
    reward for that episode. If an episode number appears on more than one
    row, the value from the last such row is kept.

    :param av_rewards_csv_file: The average rewards per episode csv file path.
    :return: The average rewards per episode csv as a dict.
    :raises KeyError: If the csv has no "Episode" or no "Average Reward"
        column.
    """
    # as_series=False yields a plain Python list per column, so the two
    # columns can be zipped directly rather than indexing a polars Series
    # element-by-element inside a Python loop.
    columns = pl.read_csv(av_rewards_csv_file).to_dict(as_series=False)
    return dict(zip(columns["Episode"], columns["Average Reward"]))