Get sb3 checkpoints saving during training
This commit is contained in:
@@ -2,7 +2,7 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_steps: 2560
|
||||
n_learn_episodes: 25
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
@@ -10,6 +10,10 @@ training_config:
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
|
||||
@@ -32,8 +32,8 @@ class SessionIO:
|
||||
"""
|
||||
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings = settings
|
||||
self.session_path = self.generate_session_path()
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
|
||||
@@ -45,12 +45,12 @@ class PolicyABC(ABC):
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, n_episodes: int, n_time_steps: int) -> None:
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
|
||||
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
@@ -36,9 +37,17 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def learn(self, n_time_steps: int) -> None:
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
self._agent.learn(total_timesteps=n_time_steps)
|
||||
if self.session.io_manager.settings.save_checkpoints:
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
|
||||
save_path=self.session.io_manager.generate_model_save_path("sb3"),
|
||||
name_prefix="sb3_model",
|
||||
)
|
||||
else:
|
||||
checkpoint_callback = None
|
||||
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
@@ -60,12 +69,10 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
Therefore, this method is only used to save the final model.
|
||||
"""
|
||||
self._agent.save(save_path)
|
||||
pass
|
||||
|
||||
def load(self) -> None:
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load agent from a checkpoint."""
|
||||
self._agent_class.load("temp/path/to/save.pth", env=self.session.env)
|
||||
pass
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
|
||||
@@ -112,11 +112,12 @@ class TrainingOptions(BaseModel):
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
seed: Optional[int]
|
||||
n_learn_steps: int
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
@@ -188,13 +189,18 @@ class PrimaiteSession:
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
self.training_progress_bar = progress_bar_manager.counter(
|
||||
total=self.training_options.n_learn_steps, desc="Training steps"
|
||||
)
|
||||
n_learn_steps = self.training_options.n_learn_steps
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
self.training_progress_bar = progress_bar_manager.counter(
|
||||
total=n_learn_episodes * max_steps_per_episode, desc="Training steps"
|
||||
)
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(n_time_steps=n_learn_steps)
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
@@ -513,6 +519,6 @@ class PrimaiteSession:
|
||||
|
||||
# READ IO SETTINGS
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIO(settings=SessionIOSettings(**io_settings))
|
||||
sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings))
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user