Files
PrimAITE/src/primaite/agents/rllib.py

287 lines
11 KiB
Python

# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
import shutil
import zipfile
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
from primaite.environment.primaite_env import Primaite
from primaite.exceptions import RLlibAgentError
_LOGGER: Logger = getLogger(__name__)
# TODO: verify type of env_config
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"],
)
# TODO: verify type hint return type
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config: Dict) -> UnifiedLogger:
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the RLLib Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB")
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.critical(msg)
raise NotImplementedError(msg)
super().__init__(training_config_path, lay_down_config_path)
if self._training_config.session_type == SessionType.EVAL:
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
_LOGGER.critical(msg)
raise RLlibAgentError(msg)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config_class: Union[PPOConfig, A2CConfig]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: Union[PPOConfig, A2CConfig]
self._current_result: dict
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
self._train_agent = None # Required to capture the learning agent to close after eval
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self) -> None:
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
self._agent_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_train_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
super().learn()
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent = self._agent
else:
self._agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
def _unpack_saved_agent_into_eval(self) -> Path:
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
agent_restore_path = self.evaluation_path / "agent_restore"
if agent_restore_path.exists():
shutil.rmtree(agent_restore_path)
agent_restore_path.mkdir()
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
zip_file.extractall(agent_restore_path)
return agent_restore_path
def _setup_eval(self):
self._can_learn = False
self._can_evaluate = True
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._setup_eval()
self._env: Primaite = Primaite(
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
)
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action = self._agent.compute_single_action(observation=obs, explore=False)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
# Now we're safe to close the learning agent and write the mean rewards per episode for it
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# Perform a clean-up of the unpacked agent
if (self.evaluation_path / "agent_restore").exists():
shutil.rmtree((self.evaluation_path / "agent_restore"))
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError
@classmethod
def load(cls, path: Union[str, Path]) -> RLlibAgent:
"""Load an agent from file."""
raise NotImplementedError
def save(self, overwrite_existing: bool = True) -> None:
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError