diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index dca9620f..e0ff9276 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 2560 + n_learn_episodes: 25 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false @@ -10,6 +10,10 @@ training_config: agent_references: - defender +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + game_config: ports: diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py index 76d5ed1c..e613316d 100644 --- a/src/primaite/game/io.py +++ b/src/primaite/game/io.py @@ -32,8 +32,8 @@ class SessionIO: """ def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: - self.settings = settings - self.session_path = self.generate_session_path() + self.settings: SessionIOSettings = settings + self.session_path: Path = self.generate_session_path() def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 6a2381c1..a7052367 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -45,12 +45,12 @@ class PolicyABC(ABC): """Reference to the session.""" @abstractmethod - def learn(self, n_episodes: int, n_time_steps: int) -> None: + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" pass @abstractmethod - def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: + def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None: """Evaluate the agent.""" pass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 1be4f915..10f22e05 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -4,6 +4,7 @@ from typing import Literal, Optional, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP +from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy as PPO_MLP @@ -36,9 +37,17 @@ class SB3Policy(PolicyABC, identifier="SB3"): seed=seed, ) - def learn(self, n_time_steps: int) -> None: + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - self._agent.learn(total_timesteps=n_time_steps) + if self.session.io_manager.settings.save_checkpoints: + checkpoint_callback = CheckpointCallback( + save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval, + save_path=self.session.io_manager.generate_model_save_path("sb3"), + name_prefix="sb3_model", + ) + else: + checkpoint_callback = None + self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback) def eval(self, n_episodes: int, deterministic: bool) -> None: """Evaluate the agent.""" @@ -60,12 +69,10 @@ class SB3Policy(PolicyABC, identifier="SB3"): Therefore, this method is only used to save the final model. """ self._agent.save(save_path) - pass - def load(self) -> None: + def load(self, model_path: Path) -> None: """Load agent from a checkpoint.""" - self._agent_class.load("temp/path/to/save.pth", env=self.session.env) - pass + self._agent = self._agent_class.load(model_path, env=self.session.env) def close(self) -> None: """Close the agent.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 37c34da9..a2e83cbb 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -112,11 +112,12 @@ class TrainingOptions(BaseModel): rl_framework: Literal["SB3", "RLLIB"] rl_algorithm: Literal["PPO", "A2C"] - seed: Optional[int] - n_learn_steps: int + n_learn_episodes: int n_eval_episodes: Optional[int] = None max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None deterministic_eval: bool + seed: Optional[int] n_agents: int agent_references: List[str] @@ -188,13 +189,18 @@ class PrimaiteSession: def start_session(self) -> None: """Commence the training session.""" self.mode = SessionMode.TRAIN - self.training_progress_bar = progress_bar_manager.counter( - total=self.training_options.n_learn_steps, desc="Training steps" - ) - n_learn_steps = self.training_options.n_learn_steps + n_learn_episodes = self.training_options.n_learn_episodes n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + self.training_progress_bar = progress_bar_manager.counter( + total=n_learn_episodes * max_steps_per_episode, desc="Training steps" + ) + deterministic_eval = self.training_options.deterministic_eval - self.policy.learn(n_time_steps=n_learn_steps) + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) self.save_models() self.mode = SessionMode.EVAL @@ -513,6 +519,6 @@ class PrimaiteSession: # READ IO SETTINGS io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIO(settings=SessionIOSettings(**io_settings)) + sess.io_manager = SessionIO(settings=SessionIOSettings(**io_settings)) return sess