Merge branch 'dev' into feature/1647_Append_version_number_to_the_primaite_root_dir

This commit is contained in:
Chris McCarthy
2023-07-21 14:01:45 +01:00
3 changed files with 53 additions and 65 deletions

View File

@@ -159,15 +159,18 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
from primaite.main import run
if load is not None:
# run a loaded session
run(session_path=load)
if not tc:
tc = main_training_config_path()
else:
# start a new session using tc and ldc
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
run(training_config_path=tc, lay_down_config_path=ldc)
run(training_config_path=tc, lay_down_config_path=ldc)
@app.command()

View File

@@ -72,13 +72,7 @@ class PrimaiteSession:
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) # noqa
def setup(self) -> None:
"""Performs the session setup."""

View File

@@ -6,10 +6,11 @@ from pathlib import Path
from typing import Union
from uuid import uuid4
import pytest
from typer.testing import CliRunner
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.cli import app
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.main import run
from primaite.primaite_session import PrimaiteSession
@@ -18,6 +19,24 @@ from tests import TEST_ASSETS_ROOT
_LOGGER = getLogger(__name__)
runner = CliRunner()
sb3_expected_avg_reward_per_episode = {
10: 0.0,
11: -0.0011074218750000008,
12: -0.0010000000000000007,
13: -0.0016601562500000013,
14: -0.001400390625000001,
15: -0.0009863281250000007,
16: -0.0011855468750000008,
17: -0.0009511718750000007,
18: -0.0008789062500000007,
19: -0.0012226562500000009,
20: -0.0010292968750000007,
}
sb3_expected_eval_rewards = -0.0018515625000000014
def copy_session_asset(asset_path: Union[str, Path]) -> str:
"""Copies the asset into a temporary test folder."""
@@ -43,25 +62,8 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
return copy_path
@pytest.mark.xfail(
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
)
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
loaded_agent = SB3Agent(session_path=test_path)
@@ -82,7 +84,7 @@ def test_load_sb3_session():
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# run an evaluation
loaded_agent.evaluate()
@@ -96,29 +98,14 @@ def test_load_sb3_session():
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
# delete the test directory
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
@@ -143,7 +130,7 @@ def test_load_primaite_session():
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# run an evaluation
session.evaluate()
@@ -157,29 +144,14 @@ def test_load_primaite_session():
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
# delete the test directory
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_run_loading():
"""Test loading session via main.run."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
@@ -190,7 +162,26 @@ def test_run_loading():
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)
def test_cli():
"""Test loading session via CLI."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
result = runner.invoke(app, ["session", "--load", test_path])
# cli should work
assert result.exit_code == 0
learn_mean_rewards = av_rewards_dict(
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)