temp
This commit is contained in:
@@ -257,10 +257,16 @@ class AgentSessionABC(ABC):
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip")
|
||||
return self.learning_path / file_name
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
self._agent.save(self.session_path)
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
@@ -83,8 +85,10 @@ class RLlibAgent(AgentSessionABC):
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
metadata_dict["total_episodes"] = self._current_result[
|
||||
"episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result[
|
||||
"timesteps_total"]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -107,7 +111,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.training(
|
||||
train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
@@ -115,18 +120,21 @@ class RLlibAgent(AgentSessionABC):
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
self._agent: Algorithm = self._agent_config.build(
|
||||
logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
@@ -136,16 +144,18 @@ class RLlibAgent(AgentSessionABC):
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self.save()
|
||||
self._agent.stop()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
@@ -162,9 +172,29 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
def save(self, overwrite_existing: bool = True):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
temp_dir.mkdir()
|
||||
|
||||
# Save the agent to the temp dir
|
||||
self._agent.save(str(temp_dir))
|
||||
|
||||
# Capture the saved Rllib checkpoint inside the temp directory
|
||||
for file in temp_dir.iterdir():
|
||||
checkpoint_dir = file
|
||||
break
|
||||
|
||||
shutil.make_archive(
|
||||
str(self._saved_agent_path).replace(".zip", ""),
|
||||
"zip",
|
||||
checkpoint_dir # noqa
|
||||
)
|
||||
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
|
||||
@@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC):
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
@@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self._env.reset()
|
||||
self.save()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
@@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
|
||||
Reference in New Issue
Block a user