Separate game, environment, and session
This commit is contained in:
63
src/primaite/session/io.py
Normal file
63
src/primaite/session/io.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class SessionIOSettings(BaseModel):
|
||||
"""Schema for session IO settings."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
save_final_model: bool = True
|
||||
"""Whether to save the final model right at the end of training."""
|
||||
save_checkpoints: bool = False
|
||||
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
|
||||
checkpoint_interval: int = 10
|
||||
"""How often to save a checkpoint model (if save_checkpoints is True)."""
|
||||
save_logs: bool = True
|
||||
"""Whether to save logs"""
|
||||
save_transactions: bool = True
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
"""
|
||||
Class for managing session IO.
|
||||
|
||||
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
|
||||
# set global SIM_OUTPUT path
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
time_str = timestamp.strftime("%H-%M-%S")
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
return session_path
|
||||
|
||||
def generate_model_save_path(self, agent_name: str) -> Path:
|
||||
"""Return the path where the final model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_final"
|
||||
|
||||
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
|
||||
"""Return the path where the checkpoint model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"
|
||||
Reference in New Issue
Block a user