Merged PR 109: Auto save agent at end of training
## Summary * Made RLlib and SB3 agents save at the end of each learning session by default using a common file naming format. Also now agents only checkpoint every n and not on the final episode. ## Test process *Tests saved agent file in the test_primaite_session test. ## Checklist - [X] This PR is linked to a **work item** - [X] I have performed **self-review** of the code - [X] I have written **tests** for any new functionality added with this PR - [ ] I have updated the **documentation** if this PR changes or adds functionality - [X] I have run **pre-commit** checks for code style Related work items: #1593
This commit is contained in:
@@ -257,10 +257,19 @@ class AgentSessionABC(ABC):
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (
|
||||
f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip"
|
||||
)
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
self._agent.save(self.session_path)
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
@@ -120,9 +122,11 @@ class RLlibAgent(AgentSessionABC):
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
@@ -140,6 +144,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self.save()
|
||||
self._agent.stop()
|
||||
super().learn()
|
||||
|
||||
@@ -162,9 +167,25 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
def save(self, overwrite_existing: bool = True):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
temp_dir.mkdir()
|
||||
|
||||
# Save the agent to the temp dir
|
||||
self._agent.save(str(temp_dir))
|
||||
|
||||
# Capture the saved Rllib checkpoint inside the temp directory
|
||||
for file in temp_dir.iterdir():
|
||||
checkpoint_dir = file
|
||||
break
|
||||
|
||||
# Zip the folder
|
||||
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
|
||||
@@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC):
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
@@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self._env.reset()
|
||||
self.save()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
@@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
|
||||
@@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session):
|
||||
# Check that the network png file exists
|
||||
assert (session_path / f"network_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that the saved agent exists
|
||||
assert session._agent_session._saved_agent_path.exists()
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
|
||||
Reference in New Issue
Block a user