diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 685fe776..32118597 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -257,10 +257,19 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass + @property + def _saved_agent_path(self) -> Path: + file_name = ( + f"{self._training_config.agent_framework}_" + f"{self._training_config.agent_identifier}_" + f"{self.timestamp_str}.zip" + ) + return self.learning_path / file_name + @abstractmethod def save(self): """Save the agent.""" - self._agent.save(self.session_path) + pass @abstractmethod def export(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d851ba9c..427072c4 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +import shutil from datetime import datetime from pathlib import Path from typing import Union +from uuid import uuid4 from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig @@ -120,9 +122,11 @@ class RLlibAgent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - self._agent.save(str(self.checkpoints_path)) + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + self._agent.save(str(self.checkpoints_path)) def learn( self, @@ -140,6 +144,7 @@ class RLlibAgent(AgentSessionABC): for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() + self.save() self._agent.stop() super().learn() @@ -162,9 +167,25 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): + def save(self, overwrite_existing: bool = True): """Save the agent.""" - raise NotImplementedError + # Make temp dir to save in isolation + temp_dir = self.learning_path / str(uuid4()) + temp_dir.mkdir() + + # Save the agent to the temp dir + self._agent.save(str(temp_dir)) + + # Capture the saved Rllib checkpoint inside the temp directory + for file in temp_dir.iterdir(): + checkpoint_dir = file + break + + # Zip the folder + shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa + + # Drop the temp directory + shutil.rmtree(temp_dir) def export(self): """Export the agent to transportable file format.""" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index f5ac44cb..18e208e4 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - self._agent.save(checkpoint_path) - _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() self._env.reset() + self.save() self._env.close() super().learn() @@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC): def save(self): """Save the agent.""" - raise NotImplementedError + self._agent.save(self._saved_agent_path) def export(self): """Export the agent to transportable file format.""" diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index ae0b0870..75ea5882 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session): # Check that the network png file exists assert (session_path / f"network_{session.timestamp_str}.png").exists() + # Check that the saved agent exists + assert session._agent_session._saved_agent_path.exists() + # Check that both the transactions and av reward csv files exist for file in session.learning_path.iterdir(): if file.suffix == ".csv":