Improve config validation and fix tests
This commit is contained in:
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
@@ -10,6 +10,8 @@ from primaite import PRIMAITE_PATHS
|
||||
class SessionIOSettings(BaseModel):
|
||||
"""Schema for session IO settings."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
save_final_model: bool = True
|
||||
"""Whether to save the final model right at the end of training."""
|
||||
save_checkpoints: bool = False
|
||||
@@ -34,6 +36,8 @@ class SessionIO:
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
import enlighten
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
@@ -104,6 +104,8 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
@@ -111,6 +113,8 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
@@ -522,6 +526,6 @@ class PrimaiteSession:
|
||||
|
||||
# READ IO SETTINGS
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings))
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
return sess
|
||||
|
||||
@@ -10,6 +10,10 @@ training_config:
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
|
||||
@@ -32,7 +32,18 @@ class TestPrimaiteSession:
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
# TODO: check that env was closed, that the model was saved, etc.
|
||||
|
||||
session_path = session.io_manager.session_path
|
||||
assert session_path.exists()
|
||||
print(list(session_path.glob("*")))
|
||||
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
|
||||
assert checkpoint_dir.exists()
|
||||
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
|
||||
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
|
||||
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
|
||||
assert checkpoint_1.exists()
|
||||
assert checkpoint_2.exists()
|
||||
assert not checkpoint_3.exists()
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
|
||||
def test_training_only_session(self, temp_primaite_session):
|
||||
|
||||
Reference in New Issue
Block a user