diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 14db236c..3932c1bb 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -173,15 +173,18 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[ from primaite.main import run if load is not None: + # run a loaded session run(session_path=load) - if not tc: - tc = main_training_config_path() + else: + # start a new session using tc and ldc + if not tc: + tc = main_training_config_path() - if not ldc: - ldc = dos_very_basic_config_path() + if not ldc: + ldc = dos_very_basic_config_path() - run(training_config_path=tc, lay_down_config_path=ldc) + run(training_config_path=tc, lay_down_config_path=ldc) @app.command() diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index e6c0a85a..9203122a 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -72,13 +72,7 @@ class PrimaiteSession: if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - - self._agent_session: AgentSessionABC = None # noqa - self.session_path: Path = None # noqa - self.timestamp_str: str = None # noqa - self.learning_path: Path = None # noqa - self.evaluation_path: Path = None # noqa + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) # noqa def setup(self) -> None: """Performs the session setup.""" diff --git a/tests/test_session_loading.py b/tests/test_session_loading.py index c624e200..714b363f 100644 --- a/tests/test_session_loading.py +++ b/tests/test_session_loading.py @@ -6,10 +6,11 @@ from pathlib import Path from typing import Union from uuid import uuid4 -import pytest +from typer.testing import CliRunner from primaite import getLogger from primaite.agents.sb3 import SB3Agent +from primaite.cli import app from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.main import run from primaite.primaite_session import PrimaiteSession @@ -18,6 +19,24 @@ from tests import TEST_ASSETS_ROOT _LOGGER = getLogger(__name__) +runner = CliRunner() + +sb3_expected_avg_reward_per_episode = { + 10: 0.0, + 11: -0.0011074218750000008, + 12: -0.0010000000000000007, + 13: -0.0016601562500000013, + 14: -0.001400390625000001, + 15: -0.0009863281250000007, + 16: -0.0011855468750000008, + 17: -0.0009511718750000007, + 18: -0.0008789062500000007, + 19: -0.0012226562500000009, + 20: -0.0010292968750000007, +} + +sb3_expected_eval_rewards = -0.0018515625000000014 + def copy_session_asset(asset_path: Union[str, Path]) -> str: """Copies the asset into a temporary test folder.""" @@ -43,25 +62,8 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str: return copy_path -@pytest.mark.xfail( - reason="Loading works fine but the exact values change with code changes, a bug report has been created." -) def test_load_sb3_session(): """Test that loading an SB3 agent works.""" - expected_learn_mean_reward_per_episode = { - 10: 0, - 11: -0.008037109374999995, - 12: -0.007978515624999988, - 13: -0.008191406249999991, - 14: -0.00817578124999999, - 15: -0.008085937499999998, - 16: -0.007837890624999982, - 17: -0.007798828124999992, - 18: -0.007777343749999998, - 19: -0.007958984374999988, - 20: -0.0077499999999999835, - } - test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") loaded_agent = SB3Agent(session_path=test_path) @@ -82,7 +84,7 @@ def test_load_sb3_session(): ) # run is seeded so should have the expected learn value - assert learn_mean_rewards == expected_learn_mean_reward_per_episode + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode # run an evaluation loaded_agent.evaluate() @@ -96,29 +98,14 @@ def test_load_sb3_session(): assert len(set(eval_mean_reward.values())) == 1 # the evaluation should be the same as a previous run - assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards # delete the test directory shutil.rmtree(test_path) -@pytest.mark.xfail(reason="Temporarily don't worry about this not working") def test_load_primaite_session(): """Test that loading a Primaite session works.""" - expected_learn_mean_reward_per_episode = { - 10: 0, - 11: -0.008037109374999995, - 12: -0.007978515624999988, - 13: -0.008191406249999991, - 14: -0.00817578124999999, - 15: -0.008085937499999998, - 16: -0.007837890624999982, - 17: -0.007798828124999992, - 18: -0.007777343749999998, - 19: -0.007958984374999988, - 20: -0.0077499999999999835, - } - test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") # create loaded session @@ -143,7 +130,7 @@ def test_load_primaite_session(): ) # run is seeded so should have the expected learn value - assert learn_mean_rewards == expected_learn_mean_reward_per_episode + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode # run an evaluation session.evaluate() @@ -157,29 +144,14 @@ def test_load_primaite_session(): assert len(set(eval_mean_reward.values())) == 1 # the evaluation should be the same as a previous run - assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards # delete the test directory shutil.rmtree(test_path) -@pytest.mark.xfail(reason="Temporarily don't worry about this not working") def test_run_loading(): """Test loading session via main.run.""" - expected_learn_mean_reward_per_episode = { - 10: 0, - 11: -0.008037109374999995, - 12: -0.007978515624999988, - 13: -0.008191406249999991, - 14: -0.00817578124999999, - 15: -0.008085937499999998, - 16: -0.007837890624999982, - 17: -0.007798828124999992, - 18: -0.007777343749999998, - 19: -0.007958984374999988, - 20: -0.0077499999999999835, - } - test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") # create loaded session @@ -190,7 +162,26 @@ def test_run_loading(): ) # run is seeded so should have the expected learn value - assert learn_mean_rewards == expected_learn_mean_reward_per_episode + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode + + # delete the test directory + shutil.rmtree(test_path) + + +def test_cli(): + """Test loading session via CLI.""" + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + result = runner.invoke(app, ["session", "--load", test_path]) + + # cli should work + assert result.exit_code == 0 + + learn_mean_rewards = av_rewards_dict( + next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None) + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode # delete the test directory shutil.rmtree(test_path)