Improve config validation and fix tests

This commit is contained in:
Marek Wolan
2023-11-16 16:14:50 +00:00
parent 0861663cc1
commit ba580b00b4
4 changed files with 27 additions and 4 deletions

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from primaite import PRIMAITE_PATHS
@@ -10,6 +10,8 @@ from primaite import PRIMAITE_PATHS
class SessionIOSettings(BaseModel):
"""Schema for session IO settings."""
model_config = ConfigDict(extra="forbid")
save_final_model: bool = True
"""Whether to save the final model right at the end of training."""
save_checkpoints: bool = False
@@ -34,6 +36,8 @@ class SessionIO:
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
self.settings: SessionIOSettings = settings
self.session_path: Path = self.generate_session_path()
# warning TODO: must be careful not to re-initialise SessionIO, because a new session path is generated each time
# it is initialised. A refactor may be needed.
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
import enlighten
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
@@ -104,6 +104,8 @@ class PrimaiteSessionOptions(BaseModel):
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
"""
model_config = ConfigDict(extra="forbid")
ports: List[str]
protocols: List[str]
@@ -111,6 +113,8 @@ class PrimaiteSessionOptions(BaseModel):
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
@@ -522,6 +526,6 @@ class PrimaiteSession:
# READ IO SETTINGS
io_settings = cfg.get("io_settings", {})
sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings))
sess.io_manager.settings = SessionIOSettings(**io_settings)
return sess

View File

@@ -10,6 +10,10 @@ training_config:
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
ports:

View File

@@ -32,7 +32,18 @@ class TestPrimaiteSession:
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: check that env was closed, that the model was saved, etc.
session_path = session.io_manager.session_path
assert session_path.exists()
print(list(session_path.glob("*")))
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
assert checkpoint_dir.exists()
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
assert checkpoint_1.exists()
assert checkpoint_2.exists()
assert not checkpoint_3.exists()
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
def test_training_only_session(self, temp_primaite_session):