Merged PR 109: Auto save agent at end of training

## Summary
* Made RLlib and SB3 agents save at the end of each learning session by default using a common file naming format. Also now agents only checkpoint every n and not on the final episode.

## Test process
*Tests saved agent file in the test_primaite_session test.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1593
This commit is contained in:
Christopher McCarthy
2023-07-06 16:29:48 +00:00
4 changed files with 48 additions and 12 deletions

View File

@@ -257,10 +257,19 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
return self.learning_path / file_name
@abstractmethod
def save(self):
"""Save the agent."""
self._agent.save(self.session_path)
pass
@abstractmethod
def export(self):

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
@@ -120,9 +122,11 @@ class RLlibAgent(AgentSessionABC):
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
@@ -140,6 +144,7 @@ class RLlibAgent(AgentSessionABC):
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
@@ -162,9 +167,25 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self):
def save(self, overwrite_existing: bool = True):
"""Save the agent."""
raise NotImplementedError
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC):
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
@@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self.save()
self._env.close()
super().learn()
@@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC):
def save(self):
"""Save the agent."""
raise NotImplementedError
self._agent.save(self._saved_agent_path)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session):
# Check that the network png file exists
assert (session_path / f"network_{session.timestamp_str}.png").exists()
# Check that the saved agent exists
assert session._agent_session._saved_agent_path.exists()
# Check that both the transactions and av reward csv files exist
for file in session.learning_path.iterdir():
if file.suffix == ".csv":