#1595: possibly fixed the tests by fixing the bug

This commit is contained in:
Czar Echavez
2023-07-21 09:17:38 +01:00
parent 6930c8ba7b
commit f313709f47

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Union
from uuid import uuid4
import pytest
from typer.testing import CliRunner
from primaite import getLogger
@@ -23,19 +22,21 @@ _LOGGER = getLogger(__name__)
runner = CliRunner()
sb3_expected_avg_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
10: 0.0,
11: -0.0011074218750000008,
12: -0.0010000000000000007,
13: -0.0016601562500000013,
14: -0.001400390625000001,
15: -0.0009863281250000007,
16: -0.0011855468750000008,
17: -0.0009511718750000007,
18: -0.0008789062500000007,
19: -0.0012226562500000009,
20: -0.0010292968750000007,
}
sb3_expected_mean_rewards = -0.0018515625000000014
def copy_session_asset(asset_path: Union[str, Path]) -> str:
"""Copies the asset into a temporary test folder."""
@@ -61,9 +62,6 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
return copy_path
@pytest.mark.xfail(
reason="Loading works fine but the exact values change with code changes, a bug report has been created."
)
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
@@ -100,13 +98,12 @@ def test_load_sb3_session():
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_mean_rewards
# delete the test directory
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
@@ -147,13 +144,12 @@ def test_load_primaite_session():
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_mean_rewards
# delete the test directory
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_run_loading():
"""Test loading session via main.run."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
@@ -172,7 +168,6 @@ def test_run_loading():
shutil.rmtree(test_path)
@pytest.mark.xfail(reason="Temporarily don't worry about this not working")
def test_cli():
"""Test loading session via CLI."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")