#1594 - Managed to get the evaluation of rllib agents working. A test has been added to test_primaite_session.py that now tests the full RLlib agent from end-to-end. I;ve also updated the tests in here to check that the mean reward per episode plot is created for both too. This will need a bit of a re-design further down the line, but for now, it works. Added a custom exception for RLlib eval only error.
This commit is contained in:
@@ -5,18 +5,25 @@ import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[main_training_config_path(), dos_very_basic_config_path()]],
|
||||
[
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()],
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
"""Tests the PrimaiteSession class and its outputs."""
|
||||
"""
|
||||
Tests the PrimaiteSession class and all of its outputs.
|
||||
|
||||
This test runs for both a Stable Baselines3 agent, and a Ray RLlib agent.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
session_path = session.session_path
|
||||
assert session_path.exists()
|
||||
@@ -47,6 +54,17 @@ def test_primaite_session(temp_primaite_session):
|
||||
if file.suffix == ".csv":
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
# Check that the average reward per episode plots exist
|
||||
assert (session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
|
||||
assert (session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that the metadata has captured the correct number of learning and eval episodes and steps
|
||||
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
|
||||
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
|
||||
|
||||
assert len(session.eval_av_reward_per_episode_dict().keys()) == 3
|
||||
assert len(session.eval_all_transactions_dict().keys()) == 3 * 256
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
for file in file_names:
|
||||
|
||||
Reference in New Issue
Block a user