64 lines
2.6 KiB
Python
64 lines
2.6 KiB
Python
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
from primaite import PRIMAITE_PATHS
|
|
from primaite.simulator import SIM_OUTPUT
|
|
|
|
|
|
class SessionIOSettings(BaseModel):
    """
    Schema for session IO settings.

    Every field has a default, so the model can be built with no arguments. Unknown keys are rejected
    (``extra="forbid"``) so that a misspelt option raises a validation error instead of being silently ignored.
    """

    model_config = ConfigDict(extra="forbid")

    save_final_model: bool = True
    """Whether to save the final model right at the end of training."""
    save_checkpoints: bool = False
    """Whether to save a checkpoint model every `checkpoint_interval` episodes."""
    checkpoint_interval: int = 10
    """How often, in episodes, to save a checkpoint model (only relevant if `save_checkpoints` is True)."""
    save_logs: bool = True
    """Whether to save logs."""
    save_transactions: bool = True
    """Whether to save transactions. If true, the session path will have a transactions folder."""
    save_tensorboard_logs: bool = False
    """Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
|
|
|
|
|
|
class SessionIO:
    """
    Class for managing session IO.

    Currently it handles path generation, but could expand to handle loading, transactions, tensorboard, and so on.

    Constructing an instance has side effects: it creates a timestamped session folder on disk and points the
    global ``SIM_OUTPUT.path`` at ``<session_path>/simulation_output``.
    """

    def __init__(self, settings: Optional["SessionIOSettings"] = None) -> None:
        """
        Store the settings, create the session folder and redirect the simulation output path into it.

        :param settings: IO settings for this session. When omitted, a fresh ``SessionIOSettings`` with default
            values is created for this instance.
        """
        # Build the default here rather than in the signature: a default written in the signature is evaluated
        # once, when the function is defined, and that single object would then be shared by every SessionIO.
        if settings is None:
            settings = SessionIOSettings()
        self.settings: SessionIOSettings = settings
        self.session_path: Path = self.generate_session_path()

        # set global SIM_OUTPUT path
        SIM_OUTPUT.path = self.session_path / "simulation_output"

        # TODO: be careful not to re-initialise SessionIO. Every construction derives the session path from the
        # current time and overwrites the global SIM_OUTPUT.path, so a second instance will generally redirect
        # output to a different folder. Possible refactor needed.

    def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
        """
        Create a folder for the session and return the path to it.

        The folder is ``PRIMAITE_PATHS.user_sessions_path / <YYYY-MM-DD> / <HH-MM-SS>``. Missing parent folders
        are created, and an already existing folder is reused without error.

        :param timestamp: Moment used to name the folder. Defaults to ``datetime.now()``.
        :return: Path to the session folder.
        """
        if timestamp is None:
            timestamp = datetime.now()
        date_str = timestamp.strftime("%Y-%m-%d")
        time_str = timestamp.strftime("%H-%M-%S")
        session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
        session_path.mkdir(exist_ok=True, parents=True)
        return session_path

    def generate_model_save_path(self, agent_name: str) -> Path:
        """
        Return the path where the final model will be saved (excluding filename extension).

        :param agent_name: Name of the agent, used as the filename prefix.
        :return: ``<session_path>/checkpoints/<agent_name>_final``. The folder is not created here.
        """
        return self.session_path / "checkpoints" / f"{agent_name}_final"

    def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
        """
        Return the path where the checkpoint model will be saved (including the ``.pt`` filename extension).

        :param agent_name: Name of the agent, used as the filename prefix.
        :param episode: Episode number embedded in the filename.
        :return: ``<session_path>/checkpoints/<agent_name>_checkpoint_<episode>.pt``. The folder is not created
            here.
        """
        return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"
|