Add ability to save sb3 final model
This commit is contained in:
54
src/primaite/game/io.py
Normal file
54
src/primaite/game/io.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
|
||||
class SessionIOSettings(BaseModel):
    """
    Schema for session IO settings.

    Every field has a default, so an empty mapping validates to the default settings.
    """

    save_final_model: bool = True
    """Whether to save the final model right at the end of training."""
    save_checkpoints: bool = False
    """Whether to save a checkpoint model every `checkpoint_interval` episodes."""
    checkpoint_interval: int = 10
    """How often, in episodes, to save a checkpoint model (if `save_checkpoints` is True)."""
    save_logs: bool = True
    """Whether to save logs."""
    save_transactions: bool = True
    """Whether to save transactions. If true, the session path will have a transactions folder."""
    save_tensorboard_logs: bool = False
    """Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
|
||||
|
||||
|
||||
class SessionIO:
    """
    Class for managing session IO.

    Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
    """

    def __init__(self, settings: Optional["SessionIOSettings"] = None) -> None:
        """
        Store the IO settings and create the output folder for this session.

        :param settings: IO settings to use. When omitted, a new ``SessionIOSettings`` with default values is
            created for this instance.
        """
        # A ``SessionIOSettings()`` default in the signature would be evaluated once, at definition time, and the
        # same object would then be shared by every instance; build it lazily instead.
        self.settings = settings if settings is not None else SessionIOSettings()
        self.session_path = self.generate_session_path()

    def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
        """
        Create a folder for the session and return the path to it.

        :param timestamp: Time used to name the folder. Defaults to ``datetime.now()``.
        :return: ``<user sessions path>/<YYYY-MM-DD>/<HH-MM-SS>``; the folder and any missing parents are created.
        """
        if timestamp is None:
            timestamp = datetime.now()
        date_str = timestamp.strftime("%Y-%m-%d")
        time_str = timestamp.strftime("%H-%M-%S")
        session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
        # exist_ok: reusing a timestamp must not fail if the folder is already there.
        session_path.mkdir(exist_ok=True, parents=True)
        return session_path

    def generate_model_save_path(self, agent_name: str) -> Path:
        """
        Return the path where the final model will be saved (excluding filename extension).

        :param agent_name: Name used as the prefix of the model file name.
        :return: ``<session path>/checkpoints/<agent_name>_final``. The path is only computed, nothing is created.
        """
        return self.session_path / "checkpoints" / f"{agent_name}_final"

    def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
        """
        Return the path where the checkpoint model will be saved (including the ``.pt`` filename extension).

        :param agent_name: Name used as the prefix of the checkpoint file name.
        :param episode: Episode number embedded in the checkpoint file name.
        :return: ``<session path>/checkpoints/<agent_name>_checkpoint_<episode>.pt``. The path is only computed,
            nothing is created.
        """
        return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -54,7 +55,7 @@ class PolicyABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
def save(self, save_path: Path) -> None:
    """
    Save the agent.

    :param save_path: Location the agent should be saved to.
    """
    pass
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -50,12 +51,15 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
)
|
||||
print(reward_data)
|
||||
|
||||
def save(self, save_path: Path) -> None:
    """
    Save the current policy parameters.

    Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
    refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
    Therefore, this method is only used to save the final model.

    :param save_path: Destination handed unchanged to the wrapped agent's ``save`` method.
    """
    self._agent.save(save_path)
|
||||
|
||||
def load(self) -> None:
|
||||
|
||||
@@ -13,6 +13,7 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.io import SessionIO, SessionIOSettings
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
@@ -179,6 +180,10 @@ class PrimaiteSession:
|
||||
"""evaluation episodes counter"""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session."""
|
||||
@@ -190,6 +195,7 @@ class PrimaiteSession:
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(n_time_steps=n_learn_steps)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
@@ -198,6 +204,11 @@ class PrimaiteSession:
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
    """
    Save the RL models.

    The policy is saved to the final-model path generated by the session IO manager. Nothing is saved when the
    IO settings set ``save_final_model`` to False.
    """
    # Default to saving when the settings object does not expose the flag, which keeps the previous behaviour.
    if not getattr(self.io_manager.settings, "save_final_model", True):
        return
    # "temp_model_name" is a placeholder prefix for the saved model file.
    save_path = self.io_manager.generate_model_save_path("temp_model_name")
    self.policy.save(save_path)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -500,4 +511,8 @@ class PrimaiteSession:
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
|
||||
# READ IO SETTINGS
io_settings = cfg.get("io_settings", {})
# Store the validated settings object itself. Wrapping it in a new SessionIO would put an IO manager where
# settings are expected and would run session-folder creation a second time.
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user