#1595: possibly fixed the tests by fixing the bug
This commit is contained in:
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -23,19 +22,21 @@ _LOGGER = getLogger(__name__)
|
||||
runner = CliRunner()
|
||||
|
||||
sb3_expected_avg_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
10: 0.0,
|
||||
11: -0.0011074218750000008,
|
||||
12: -0.0010000000000000007,
|
||||
13: -0.0016601562500000013,
|
||||
14: -0.001400390625000001,
|
||||
15: -0.0009863281250000007,
|
||||
16: -0.0011855468750000008,
|
||||
17: -0.0009511718750000007,
|
||||
18: -0.0008789062500000007,
|
||||
19: -0.0012226562500000009,
|
||||
20: -0.0010292968750000007,
|
||||
}
|
||||
|
||||
sb3_expected_mean_rewards = -0.0018515625000000014
|
||||
|
||||
|
||||
def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
"""Copies the asset into a temporary test folder."""
|
||||
@@ -61,9 +62,6 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
return copy_path
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
|
||||
)
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
@@ -100,13 +98,12 @@ def test_load_sb3_session():
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_mean_rewards
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_load_primaite_session():
|
||||
"""Test that loading a Primaite session works."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
@@ -147,13 +144,12 @@ def test_load_primaite_session():
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_mean_rewards
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_run_loading():
|
||||
"""Test loading session via main.run."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
@@ -172,7 +168,6 @@ def test_run_loading():
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
|
||||
def test_cli():
|
||||
"""Test loading session via CLI."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
Reference in New Issue
Block a user