From ba580b00b41324343c31e1ccc4f4767ffc8c26a2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 16 Nov 2023 16:14:50 +0000 Subject: [PATCH] Improve config validation and fix tests --- src/primaite/game/io.py | 6 +++++- src/primaite/game/session.py | 8 ++++++-- tests/assets/configs/test_primaite_session.yaml | 4 ++++ .../e2e_integration_tests/test_primaite_session.py | 13 ++++++++++++- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py index e613316d..d510d108 100644 --- a/src/primaite/game/io.py +++ b/src/primaite/game/io.py @@ -2,7 +2,7 @@ from datetime import datetime from pathlib import Path from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from primaite import PRIMAITE_PATHS @@ -10,6 +10,8 @@ from primaite import PRIMAITE_PATHS class SessionIOSettings(BaseModel): """Schema for session IO settings.""" + model_config = ConfigDict(extra="forbid") + save_final_model: bool = True """Whether to save the final model right at the end of training.""" save_checkpoints: bool = False @@ -34,6 +36,8 @@ class SessionIO: def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: self.settings: SessionIOSettings = settings self.session_path: Path = self.generate_session_path() + # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's + # possible refactor needed def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index f265b7d9..655e2459 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple import enlighten import gymnasium from gymnasium.core import ActType, ObsType -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager @@ -104,6 +104,8 @@ class PrimaiteSessionOptions(BaseModel): Currently this is used to restrict which ports and protocols exist in the world of the simulation. """ + model_config = ConfigDict(extra="forbid") + ports: List[str] protocols: List[str] @@ -111,6 +113,8 @@ class PrimaiteSessionOptions(BaseModel): class TrainingOptions(BaseModel): """Options for training the RL agent.""" + model_config = ConfigDict(extra="forbid") + rl_framework: Literal["SB3", "RLLIB"] rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int @@ -522,6 +526,6 @@ class PrimaiteSession: # READ IO SETTINGS io_settings = cfg.get("io_settings", {}) - sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings)) + sess.io_manager.settings = SessionIOSettings(**io_settings) return sess diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 201528eb..9445cd2b 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -10,6 +10,10 @@ training_config: agent_references: - defender +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + game_config: ports: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 5e1da4ff..3ef5b6da 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -32,7 +32,18 @@ class TestPrimaiteSession: with temp_primaite_session as session: session: TempPrimaiteSession session.start_session() - # TODO: check that env was closed, that the model was saved, etc. + + session_path = session.io_manager.session_path + assert session_path.exists() + print(list(session_path.glob("*"))) + checkpoint_dir = session_path / "checkpoints" / "sb3_final" + assert checkpoint_dir.exists() + checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip" + checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip" + checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip" + assert checkpoint_1.exists() + assert checkpoint_2.exists() + assert not checkpoint_3.exists() @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) def test_training_only_session(self, temp_primaite_session):