From 9d0a98b22122e8f8b4c000ad84d613acf45252f8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 20:30:07 +0000 Subject: [PATCH] Apply suggestions from code review --- src/primaite/game/policy/sb3.py | 4 ---- tests/assets/configs/bad_primaite_session.yaml | 2 +- tests/e2e_integration_tests/test_primaite_session.py | 6 ++++++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index bb35775a..a4870054 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -74,10 +74,6 @@ class SB3Policy(PolicyABC, identifier="SB3"): """Load agent from a checkpoint.""" self._agent = self._agent_class.load(model_path, env=self.session.env) - def close(self) -> None: - """Close the agent.""" - pass - @classmethod def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": """Create an agent from config file.""" diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 752d98a5..80567aea 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -1,7 +1,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO - se3ed: 333 + se3ed: 333 # Purposeful typo to check that error is raised with bad configuration. n_learn_steps: 2560 n_eval_episodes: 5 diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 3ef5b6da..b6122bad 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -1,3 +1,4 @@ +import pydantic import pytest from tests.conftest import TempPrimaiteSession @@ -5,6 +6,7 @@ from tests.conftest import TempPrimaiteSession CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" +MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml" class TestPrimaiteSession: @@ -60,3 +62,7 @@ class TestPrimaiteSession: session: TempPrimaiteSession session.start_session() # TODO: include checks that the model was loaded and that the eval-only session ran + + def test_error_thrown_on_bad_configuration(self): + with pytest.raises(pydantic.ValidationError): + session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)