- Added ability to load sessions via PrimaiteSession
- PrimaiteSession loading test
- Added a NotImplemented RLlib loading for now
- Added the ability to load sessions for hardcoded agents
- Moved Session metadata parsing to utils
This commit is contained in:
Czar.Echavez
2023-07-14 14:14:03 +01:00
parent dce0d10383
commit fbdb8aab28
6 changed files with 195 additions and 62 deletions

View File

@@ -8,6 +8,7 @@ from uuid import uuid4
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT
@@ -96,4 +97,59 @@ def test_load_sb3_session():
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
pass
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
session = PrimaiteSession(session_path=test_path)
# run setup on session
session.setup()
# make sure that the session was loaded correctly
assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
assert session._agent_session._training_config.deterministic
assert session._agent_session._training_config.seed == 12345
assert str(session._agent_session.session_path) == str(test_path)
# run another learn session
session.learn()
learn_mean_rewards = av_rewards_dict(
session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# run an evaluation
session.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
# delete the test directory
shutil.rmtree(test_path)