This commit is contained in:
Chris McCarthy
2023-07-06 10:07:54 +01:00
parent 6006f022a1
commit f92d2fb65d
3 changed files with 60 additions and 21 deletions

View File

@@ -257,10 +257,16 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
file_name = (f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip")
return self.learning_path / file_name
@abstractmethod
def save(self):
"""Save the agent."""
self._agent.save(self.session_path)
pass
@abstractmethod
def export(self):

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
@@ -83,8 +85,10 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
metadata_dict["total_episodes"] = self._current_result[
"episodes_total"]
metadata_dict["total_time_steps"] = self._current_result[
"timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -107,7 +111,8 @@ class RLlibAgent(AgentSessionABC):
),
)
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.training(
train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
@@ -115,18 +120,21 @@ class RLlibAgent(AgentSessionABC):
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
self._agent: Algorithm = self._agent_config.build(
logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs,
self,
**kwargs,
):
"""
Evaluate the agent.
@@ -136,16 +144,18 @@ class RLlibAgent(AgentSessionABC):
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
_LOGGER.info(
f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
def evaluate(
self,
**kwargs,
self,
**kwargs,
):
"""
Evaluate the agent.
@@ -162,9 +172,29 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self):
def save(self, overwrite_existing: bool = True):
"""Save the agent."""
raise NotImplementedError
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
shutil.make_archive(
str(self._saved_agent_path).replace(".zip", ""),
"zip",
checkpoint_dir # noqa
)
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC):
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
@@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self.save()
self._env.close()
super().learn()
@@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC):
def save(self):
"""Save the agent."""
raise NotImplementedError
self._agent.save(self._saved_agent_path)
def export(self):
"""Export the agent to transportable file format."""