diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py new file mode 100644 index 00000000..76d5ed1c --- /dev/null +++ b/src/primaite/game/io.py @@ -0,0 +1,54 @@ +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel + +from primaite import PRIMAITE_PATHS + + +class SessionIOSettings(BaseModel): + """Schema for session IO settings.""" + + save_final_model: bool = True + """Whether to save the final model right at the end of training.""" + save_checkpoints: bool = False + """Whether to save a checkpoint model every `checkpoint_interval` episodes""" + checkpoint_interval: int = 10 + """How often to save a checkpoint model (if save_checkpoints is True).""" + save_logs: bool = True + """Whether to save logs""" + save_transactions: bool = True + """Whether to save transactions, If true, the session path will have a transactions folder.""" + save_tensorboard_logs: bool = False + """Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder.""" + + +class SessionIO: + """ + Class for managing session IO. + + Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + """ + + def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: + self.settings = settings + self.session_path = self.generate_session_path() + + def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: + """Create a folder for the session and return the path to it.""" + if timestamp is None: + timestamp = datetime.now() + date_str = timestamp.strftime("%Y-%m-%d") + time_str = timestamp.strftime("%H-%M-%S") + session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str + session_path.mkdir(exist_ok=True, parents=True) + return session_path + + def generate_model_save_path(self, agent_name: str) -> Path: + """Return the path where the final model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_final" + + def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: + """Return the path where the checkpoint model will be saved (excluding filename extension).""" + return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 4c8dc447..6a2381c1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -1,5 +1,6 @@ """Base class and common logic for RL policies.""" from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, TYPE_CHECKING if TYPE_CHECKING: @@ -54,7 +55,7 @@ class PolicyABC(ABC): pass @abstractmethod - def save(self) -> None: + def save(self, save_path: Path) -> None: """Save the agent.""" pass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index ff710944..1be4f915 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,4 +1,5 @@ """Stable baselines 3 policy.""" +from pathlib import Path from typing import Literal, Optional, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO @@ -50,12 +51,15 @@ class SB3Policy(PolicyABC, identifier="SB3"): ) print(reward_data) - def save(self) -> None: - """Save the agent.""" - savepath = ( - "temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session - ) - self._agent.save(savepath) + def save(self, save_path: Path) -> None: + """ + Save the current policy parameters. + + Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please + refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information. + Therefore, this method is only used to save the final model. + """ + self._agent.save(save_path) pass def load(self) -> None: diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index e85328ef..37c34da9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -13,6 +13,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.io import SessionIO, SessionIOSettings from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer @@ -179,6 +180,10 @@ class PrimaiteSession: """evaluation episodes counter""" self.mode: SessionMode = SessionMode.MANUAL + """Current session mode.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" def start_session(self) -> None: """Commence the training session.""" @@ -190,6 +195,7 @@ class PrimaiteSession: n_eval_episodes = self.training_options.n_eval_episodes deterministic_eval = self.training_options.deterministic_eval self.policy.learn(n_time_steps=n_learn_steps) + self.save_models() self.mode = SessionMode.EVAL if n_eval_episodes > 0: @@ -198,6 +204,11 @@ class PrimaiteSession: self.mode = SessionMode.MANUAL + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) + def step(self): """ Perform one step of the simulation/agent loop. @@ -500,4 +511,8 @@ class PrimaiteSession: # CREATE POLICY sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + # READ IO SETTINGS + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIO(settings=SessionIOSettings(**io_settings)) + return sess