Save SB3 checkpoints during training

This commit is contained in:
Marek Wolan
2023-11-16 14:37:37 +00:00
parent 4cc7ba1522
commit 829500a60f
5 changed files with 36 additions and 19 deletions

View File

@@ -2,7 +2,7 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_steps: 2560
n_learn_episodes: 25
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
@@ -10,6 +10,10 @@ training_config:
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
ports:

View File

@@ -32,8 +32,8 @@ class SessionIO:
"""
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
self.settings = settings
self.session_path = self.generate_session_path()
self.settings: SessionIOSettings = settings
self.session_path: Path = self.generate_session_path()
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""

View File

@@ -45,12 +45,12 @@ class PolicyABC(ABC):
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, n_time_steps: int) -> None:
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass

View File

@@ -4,6 +4,7 @@ from typing import Literal, Optional, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
@@ -36,9 +37,17 @@ class SB3Policy(PolicyABC, identifier="SB3"):
seed=seed,
)
def learn(self, n_time_steps: int) -> None:
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
self._agent.learn(total_timesteps=n_time_steps)
if self.session.io_manager.settings.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)
else:
checkpoint_callback = None
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
@@ -60,12 +69,10 @@ class SB3Policy(PolicyABC, identifier="SB3"):
Therefore, this method is only used to save the final model.
"""
self._agent.save(save_path)
pass
def load(self) -> None:
def load(self, model_path: Path) -> None:
"""Load agent from a checkpoint."""
self._agent_class.load("temp/path/to/save.pth", env=self.session.env)
pass
self._agent = self._agent_class.load(model_path, env=self.session.env)
def close(self) -> None:
"""Close the agent."""

View File

@@ -112,11 +112,12 @@ class TrainingOptions(BaseModel):
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
seed: Optional[int]
n_learn_steps: int
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
@@ -188,13 +189,18 @@ class PrimaiteSession:
def start_session(self) -> None:
"""Commence the training session."""
self.mode = SessionMode.TRAIN
self.training_progress_bar = progress_bar_manager.counter(
total=self.training_options.n_learn_steps, desc="Training steps"
)
n_learn_steps = self.training_options.n_learn_steps
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
self.training_progress_bar = progress_bar_manager.counter(
total=n_learn_episodes * max_steps_per_episode, desc="Training steps"
)
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(n_time_steps=n_learn_steps)
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
@@ -513,6 +519,6 @@ class PrimaiteSession:
# READ IO SETTINGS
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIO(settings=SessionIOSettings(**io_settings))
sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings))
return sess