Merge branch 'dev' into feature/1647_Append_version_number_to_the_primaite_root_dir
This commit is contained in:
@@ -159,15 +159,18 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
|
||||
from primaite.main import run
|
||||
|
||||
if load is not None:
|
||||
# run a loaded session
|
||||
run(session_path=load)
|
||||
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
else:
|
||||
# start a new session using tc and ldc
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(training_config_path=tc, lay_down_config_path=ldc)
|
||||
run(training_config_path=tc, lay_down_config_path=ldc)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -72,13 +72,7 @@ class PrimaiteSession:
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
|
||||
@@ -6,10 +6,11 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.cli import app
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.main import run
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
@@ -18,6 +19,24 @@ from tests import TEST_ASSETS_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
sb3_expected_avg_reward_per_episode = {
|
||||
10: 0.0,
|
||||
11: -0.0011074218750000008,
|
||||
12: -0.0010000000000000007,
|
||||
13: -0.0016601562500000013,
|
||||
14: -0.001400390625000001,
|
||||
15: -0.0009863281250000007,
|
||||
16: -0.0011855468750000008,
|
||||
17: -0.0009511718750000007,
|
||||
18: -0.0008789062500000007,
|
||||
19: -0.0012226562500000009,
|
||||
20: -0.0010292968750000007,
|
||||
}
|
||||
|
||||
sb3_expected_eval_rewards = -0.0018515625000000014
|
||||
|
||||
|
||||
def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
"""Copies the asset into a temporary test folder."""
|
||||
@@ -43,25 +62,8 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
return copy_path
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
|
||||
)
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
loaded_agent = SB3Agent(session_path=test_path)
|
||||
@@ -82,7 +84,7 @@ def test_load_sb3_session():
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
loaded_agent.evaluate()
|
||||
@@ -96,29 +98,14 @@ def test_load_sb3_session():
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_load_primaite_session():
|
||||
"""Test that loading a Primaite session works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
@@ -143,7 +130,7 @@ def test_load_primaite_session():
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
session.evaluate()
|
||||
@@ -157,29 +144,14 @@ def test_load_primaite_session():
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_run_loading():
|
||||
"""Test loading session via main.run."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
@@ -190,7 +162,26 @@ def test_run_loading():
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
def test_cli():
|
||||
"""Test loading session via CLI."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
result = runner.invoke(app, ["session", "--load", test_path])
|
||||
|
||||
# cli should work
|
||||
assert result.exit_code == 0
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
Reference in New Issue
Block a user