Files
PrimAITE/src/primaite/agents/rllib.py
Chris McCarthy 8edb26a65c Merge remote-tracking branch 'origin/dev' into 1566-configure-episode-steps-learn-eval
# Conflicts:
#	src/primaite/agents/rllib.py
2023-07-07 14:34:20 +01:00

196 lines
6.7 KiB
Python

from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _env_creator(env_config):
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"],
)
def _custom_log_creator(session_path: Path):
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config):
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: Union[PPOConfig, A2CConfig]
self._current_result: dict
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
def _update_session_metadata_file(self):
"""
Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self):
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
self._agent_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_train_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError
def _get_latest_checkpoint(self):
raise NotImplementedError
@classmethod
def load(cls, path: Union[str, Path]) -> RLlibAgent:
"""Load an agent from file."""
raise NotImplementedError
def save(self, overwrite_existing: bool = True):
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""
raise NotImplementedError