Merge remote-tracking branch 'origin/dev' into feature/1572-fix-docs-formatting

This commit is contained in:
Marek Wolan
2023-07-07 10:30:11 +01:00
17 changed files with 590 additions and 200 deletions

View File

@@ -266,10 +266,19 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
return self.learning_path / file_name
@abstractmethod
def save(self):
"""Save the agent."""
self._agent.save(self.session_path)
pass
@abstractmethod
def export(self):

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
@@ -118,6 +120,7 @@ class RLlibAgent(AgentSessionABC):
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
@@ -132,9 +135,11 @@ class RLlibAgent(AgentSessionABC):
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
@@ -152,9 +157,14 @@ class RLlibAgent(AgentSessionABC):
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
# save agent
self.save()
def evaluate(
self,
**kwargs,
@@ -174,9 +184,25 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self):
def save(self, overwrite_existing: bool = True):
"""Save the agent."""
raise NotImplementedError
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -71,16 +71,19 @@ class SB3Agent(AgentSessionABC):
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
@@ -102,25 +105,27 @@ class SB3Agent(AgentSessionABC):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self.save()
self._env.close()
super().learn()
# save agent
self.save()
def evaluate(
self,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
@@ -131,7 +136,7 @@ class SB3Agent(AgentSessionABC):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
@@ -146,7 +151,7 @@ class SB3Agent(AgentSessionABC):
def save(self):
"""Save the agent."""
raise NotImplementedError
self._agent.save(self._saved_agent_path)
def export(self):
"""Export the agent to transportable file format."""