Add ability to save sb3 final model

This commit is contained in:
Marek Wolan
2023-11-15 16:59:56 +00:00
parent 64e8b3bcea
commit 4cc7ba1522
4 changed files with 81 additions and 7 deletions

54
src/primaite/game/io.py Normal file
View File

@@ -0,0 +1,54 @@
from datetime import datetime
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
from primaite import PRIMAITE_PATHS
class SessionIOSettings(BaseModel):
    """Pydantic schema describing which artefacts a session should write to disk."""

    save_final_model: bool = True
    """Save the final model once training has finished."""

    save_checkpoints: bool = False
    """Save a checkpoint model every `checkpoint_interval` episodes."""

    checkpoint_interval: int = 10
    """Number of episodes between checkpoint saves (only relevant if `save_checkpoints` is True)."""

    save_logs: bool = True
    """Save logs."""

    save_transactions: bool = True
    """Save transactions. If True, the session path will have a transactions folder."""

    save_tensorboard_logs: bool = False
    """Save tensorboard logs. If True, the session path will have a tensorboard_logs folder."""
class SessionIO:
    """
    Manage the input/output concerns of a session.

    Currently this only generates the paths that session artefacts are written to, but it could expand to handle
    loading, transactions, tensorboard, and so on.
    """

    def __init__(self, settings: Optional["SessionIOSettings"] = None) -> None:
        """
        Store the IO settings and create the session output folder.

        :param settings: IO settings for this session. When omitted, a new ``SessionIOSettings`` with default values
            is created for this instance.
        """
        # Build the default here rather than in the signature: a default evaluated at definition time would be a
        # single mutable settings object shared by every SessionIO created without arguments.
        self.settings = SessionIOSettings() if settings is None else settings
        self.session_path = self.generate_session_path()

    def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
        """
        Create a folder for the session and return the path to it.

        The folder is ``<user_sessions_path>/<YYYY-MM-DD>/<HH-MM-SS>``. An already existing folder is reused.

        :param timestamp: Moment used to name the folder. Defaults to the current local time.
        :return: Path to the (now existing) session folder.
        """
        if timestamp is None:
            timestamp = datetime.now()
        date_str = timestamp.strftime("%Y-%m-%d")
        time_str = timestamp.strftime("%H-%M-%S")
        session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
        session_path.mkdir(exist_ok=True, parents=True)
        return session_path

    def generate_model_save_path(self, agent_name: str) -> Path:
        """
        Return the path where the final model will be saved (excluding filename extension).

        :param agent_name: Name of the agent, used as the file name prefix.
        :return: ``<session_path>/checkpoints/<agent_name>_final``. The folder is not created here.
        """
        return self.session_path / "checkpoints" / f"{agent_name}_final"

    def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
        """
        Return the path where the checkpoint model for a given episode will be saved.

        Unlike :meth:`generate_model_save_path`, the returned file name already includes the ``.pt`` extension.

        :param agent_name: Name of the agent, used as the file name prefix.
        :param episode: Episode number embedded in the file name.
        :return: ``<session_path>/checkpoints/<agent_name>_checkpoint_<episode>.pt``. The folder is not created here.
        """
        return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"

View File

@@ -1,5 +1,6 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, TYPE_CHECKING
if TYPE_CHECKING:
@@ -54,7 +55,7 @@ class PolicyABC(ABC):
pass
@abstractmethod
def save(self) -> None:
def save(self, save_path: Path) -> None:
"""Save the agent."""
pass

View File

@@ -1,4 +1,5 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
@@ -50,12 +51,15 @@ class SB3Policy(PolicyABC, identifier="SB3"):
)
print(reward_data)
def save(self) -> None:
"""Save the agent."""
savepath = (
"temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session
)
self._agent.save(savepath)
def save(self, save_path: Path) -> None:
    """
    Save the current policy parameters.

    Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
    refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
    Therefore, this method is only used to save the final model.

    :param save_path: Destination path, passed unchanged to the wrapped agent's ``save`` method.
    """
    self._agent.save(save_path)
def load(self) -> None:

View File

@@ -13,6 +13,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
@@ -179,6 +180,10 @@ class PrimaiteSession:
"""evaluation episodes counter"""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session."""
@@ -190,6 +195,7 @@ class PrimaiteSession:
n_eval_episodes = self.training_options.n_eval_episodes
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(n_time_steps=n_learn_steps)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
@@ -198,6 +204,11 @@ class PrimaiteSession:
self.mode = SessionMode.MANUAL
def save_models(self, agent_name: str = "temp_model_name") -> None:
    """
    Save the RL model to the path generated by the session's IO manager.

    :param agent_name: Name used by the IO manager to build the model file name. Defaults to the placeholder
        ``"temp_model_name"``, so existing callers that pass no argument are unaffected.
    """
    save_path = self.io_manager.generate_model_save_path(agent_name)
    self.policy.save(save_path)
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -500,4 +511,8 @@ class PrimaiteSession:
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
# READ IO SETTINGS
# `io_manager.settings` must hold a SessionIOSettings. Wrapping the settings in a new SessionIO would store the
# wrong type there and create a second session folder as a side effect of SessionIO.__init__.
# `or {}` also covers an `io_settings` key that is present but empty (None) in the config.
io_settings = cfg.get("io_settings") or {}
sess.io_manager.settings = SessionIOSettings(**io_settings)
return sess