diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py
index 685fe776..fecf84d0 100644
--- a/src/primaite/agents/agent.py
+++ b/src/primaite/agents/agent.py
@@ -257,10 +257,18 @@ class AgentSessionABC(ABC):
             raise FileNotFoundError(msg)
         pass
 
+    @property
+    def _saved_agent_path(self) -> Path:
+        """Path of the ``.zip`` file, under the learning path, for the saved agent."""
+        file_name = (f"{self._training_config.agent_framework}_"
+                     f"{self._training_config.agent_identifier}_"
+                     f"{self.timestamp_str}.zip")
+        return self.learning_path / file_name
+
     @abstractmethod
     def save(self):
         """Save the agent."""
-        self._agent.save(self.session_path)
+        pass
 
     @abstractmethod
     def export(self):
diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py
index d851ba9c..32dc3dc0 100644
--- a/src/primaite/agents/rllib.py
+++ b/src/primaite/agents/rllib.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import json
+import shutil
 from datetime import datetime
 from pathlib import Path
 from typing import Union
+from uuid import uuid4
 
 from ray.rllib.algorithms import Algorithm
 from ray.rllib.algorithms.a2c import A2CConfig
@@ -83,8 +85,10 @@ class RLlibAgent(AgentSessionABC):
             metadata_dict = json.load(file)
 
         metadata_dict["end_datetime"] = datetime.now().isoformat()
-        metadata_dict["total_episodes"] = self._current_result["episodes_total"]
-        metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
+        metadata_dict["total_episodes"] = self._current_result[
+            "episodes_total"]
+        metadata_dict["total_time_steps"] = self._current_result[
+            "timesteps_total"]
 
         filepath = self.session_path / "session_metadata.json"
         _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -107,7 +111,8 @@ class RLlibAgent(AgentSessionABC):
             ),
         )
 
-        self._agent_config.training(train_batch_size=self._training_config.num_steps)
+        self._agent_config.training(
+            train_batch_size=self._training_config.num_steps)
         self._agent_config.framework(framework="tf")
 
         self._agent_config.rollouts(
@@ -115,18 +120,21 @@ class RLlibAgent(AgentSessionABC):
             num_envs_per_worker=1,
             horizon=self._training_config.num_steps,
         )
-        self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
+        self._agent: Algorithm = self._agent_config.build(
+            logger_creator=_custom_log_creator(self.learning_path))
 
     def _save_checkpoint(self):
         checkpoint_n = self._training_config.checkpoint_every_n_episodes
         episode_count = self._current_result["episodes_total"]
-        if checkpoint_n > 0 and episode_count > 0:
-            if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
-                self._agent.save(str(self.checkpoints_path))
+        # A non-positive interval disables checkpointing. The last episode
+        # gets no extra checkpoint here because learn() now calls save().
+        checkpoint_due = checkpoint_n > 0 and episode_count > 0
+        if checkpoint_due and episode_count % checkpoint_n == 0:
+            self._agent.save(str(self.checkpoints_path))
 
     def learn(
-        self,
-        **kwargs,
+            self,
+            **kwargs,
     ):
         """
         Evaluate the agent.
@@ -136,16 +144,18 @@ class RLlibAgent(AgentSessionABC):
         time_steps = self._training_config.num_steps
         episodes = self._training_config.num_episodes
 
-        _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
+        _LOGGER.info(f"Beginning learning for {episodes} episodes @"
+                     f" {time_steps} time steps...")
         for i in range(episodes):
             self._current_result = self._agent.train()
             self._save_checkpoint()
+        self.save()
         self._agent.stop()
         super().learn()
 
     def evaluate(
-        self,
-        **kwargs,
+            self,
+            **kwargs,
     ):
         """
         Evaluate the agent.
@@ -162,9 +172,43 @@ class RLlibAgent(AgentSessionABC):
         """Load an agent from file."""
         raise NotImplementedError
 
-    def save(self):
-        """Save the agent."""
-        raise NotImplementedError
+    def save(self, overwrite_existing: bool = True):
+        """
+        Save the agent as a zip archive at ``self._saved_agent_path``.
+
+        :param overwrite_existing: If False, refuse to replace an archive
+            that already exists at the target path.
+        :raises FileExistsError: If the archive already exists and
+            ``overwrite_existing`` is False.
+        :raises FileNotFoundError: If the agent saved nothing to archive.
+        """
+        archive_path = self._saved_agent_path
+        if archive_path.exists() and not overwrite_existing:
+            raise FileExistsError(
+                f"Saved agent already exists: {archive_path}")
+
+        # Save into a uniquely named temp dir so that whatever the agent
+        # writes can be picked up in isolation
+        temp_dir = self.learning_path / str(uuid4())
+        temp_dir.mkdir()
+        try:
+            self._agent.save(str(temp_dir))
+
+            # Capture the saved Rllib checkpoint inside the temp directory
+            checkpoint_dir = next(temp_dir.iterdir(), None)
+            if checkpoint_dir is None:
+                raise FileNotFoundError(
+                    f"No checkpoint was saved to {temp_dir}")
+
+            # make_archive appends ".zip" itself, so strip the suffix
+            shutil.make_archive(
+                str(archive_path.with_suffix("")),
+                "zip",
+                checkpoint_dir,
+            )
+        finally:
+            # Drop the temp directory even if saving or archiving failed
+            shutil.rmtree(temp_dir)
 
     def export(self):
         """Export the agent to transportable file format."""
diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py
index f5ac44cb..18e208e4 100644
--- a/src/primaite/agents/sb3.py
+++ b/src/primaite/agents/sb3.py
@@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC):
     def _save_checkpoint(self):
         checkpoint_n = self._training_config.checkpoint_every_n_episodes
         episode_count = self._env.episode_count
-        if checkpoint_n > 0 and episode_count > 0:
-            if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
-                checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
-                self._agent.save(checkpoint_path)
-                _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
+        # A non-positive interval disables checkpointing. The last episode
+        # gets no extra checkpoint here because learn() now calls save().
+        checkpoint_due = checkpoint_n > 0 and episode_count > 0
+        if checkpoint_due and episode_count % checkpoint_n == 0:
+            checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
+            self._agent.save(checkpoint_path)
+            _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
 
     def _get_latest_checkpoint(self):
         pass
@@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC):
             self._agent.learn(total_timesteps=time_steps)
             self._save_checkpoint()
             self._env.reset()
+        self.save()
         self._env.close()
         super().learn()
 
@@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC):
 
     def save(self):
         """Save the agent."""
-        raise NotImplementedError
+        self._agent.save(self._saved_agent_path)
 
     def export(self):
         """Export the agent to transportable file format."""