Merged PR 121: #1629 - Added rllib test

## Summary
Quick test that uses RLLIB in a session

## Test process
The learning session completes then we check that the number of rows in both the average reward per episode and all transactions csv files.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

#1629 - Added rllib test

Related work items: #1629
This commit is contained in:
Christopher McCarthy
2023-07-17 17:28:51 +00:00
6 changed files with 232 additions and 11 deletions

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
@@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.
:return: The average rewards per episode csv as a dict.
"""
df = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""
Read an all transactions csv file and return as a dict.
The dict keys are a tuple with the structure (episode, step). The dict
values are the remaining columns as a dict.
:param all_transactions_csv_file: The all transactions csv file path.
:return: The all transactions csv file as a dict.
"""
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
new_dict = {}
episodes = df_dict["Episode"]
steps = df_dict["Step"]
keys = list(df_dict.keys())
for i in range(len(episodes)):
key = (episodes[i], steps[i])
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
new_dict[key] = value_dict
return new_dict