Merged PR 121: #1629 - Added rllib test

## Summary
Quick test that uses RLLIB in a session

## Test process
The learning session completes then we check that the number of rows in both the average reward per episode and all transactions csv files.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

#1629 - Added rllib test

Related work items: #1629
This commit is contained in:
Christopher McCarthy
2023-07-17 17:28:51 +00:00
6 changed files with 232 additions and 11 deletions

View File

@@ -5,7 +5,7 @@ import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Tuple, Union
from unittest.mock import patch
import pytest
@@ -13,7 +13,7 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
def learn_av_reward_per_episode(self) -> Dict[int, float]:
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file: