Files
PrimAITE/src/primaite/agents/rllib.py
2023-07-31 12:13:52 +01:00

288 lines
12 KiB
Python

# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# from __future__ import annotations
# import json
# import shutil
# import zipfile
# from datetime import datetime
# from logging import Logger
# from pathlib import Path
# from typing import Any, Callable, Dict, Optional, Union
# from uuid import uuid4
# from primaite import getLogger
# from primaite.agents.agent_abc import AgentSessionABC
# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
# from primaite.environment.primaite_env import Primaite
# # from ray.rllib.algorithms import Algorithm
# # from ray.rllib.algorithms.a2c import A2CConfig
# # from ray.rllib.algorithms.ppo import PPOConfig
# # from ray.tune.logger import UnifiedLogger
# # from ray.tune.registry import register_env
# # from primaite.exceptions import RLlibAgentError
# _LOGGER: Logger = getLogger(__name__)
# # TODO: verify type of env_config
# def _env_creator(env_config: Dict[str, Any]) -> Primaite:
# return Primaite(
# training_config_path=env_config["training_config_path"],
# lay_down_config_path=env_config["lay_down_config_path"],
# session_path=env_config["session_path"],
# timestamp_str=env_config["timestamp_str"],
# )
# # # TODO: verify type hint return type
# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
# # logdir = session_path / "ray_results"
# # logdir.mkdir(parents=True, exist_ok=True)
# # def logger_creator(config: Dict) -> UnifiedLogger:
# # return UnifiedLogger(config, logdir, loggers=None)
# # return logger_creator
# # class RLlibAgent(AgentSessionABC):
# # """An AgentSession class that implements a Ray RLlib agent."""
# # def __init__(
# # self,
# # training_config_path: Optional[Union[str, Path]] = "",
# # lay_down_config_path: Optional[Union[str, Path]] = "",
# # session_path: Optional[Union[str, Path]] = None,
# # ) -> None:
# # """
# # Initialise the RLLib Agent training session.
# # :param training_config_path: YAML file containing configurable items defined in
# # `primaite.config.training_config.TrainingConfig`
# # :type training_config_path: Union[Path, str]
# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
# # :type lay_down_config_path: Union[Path, str]
# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB")
# # :raises ValueError: If the training config contains a bad value for agent_identifier (should be `PPO`
# # or `A2C`)
# # """
# # # TODO: implement RLlib agent loading
# # if session_path is not None:
# # msg = "RLlib agent loading has not been implemented yet"
# # _LOGGER.critical(msg)
# # raise NotImplementedError(msg)
# # super().__init__(training_config_path, lay_down_config_path)
# # if self._training_config.session_type == SessionType.EVAL:
# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
# # _LOGGER.critical(msg)
# # raise RLlibAgentError(msg)
# # if not self._training_config.agent_framework == AgentFramework.RLLIB:
# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
# # _LOGGER.error(msg)
# # raise ValueError(msg)
# # self._agent_config_class: Union[PPOConfig, A2CConfig]
# # if self._training_config.agent_identifier == AgentIdentifier.PPO:
# # self._agent_config_class = PPOConfig
# # elif self._training_config.agent_identifier == AgentIdentifier.A2C:
# # self._agent_config_class = A2CConfig
# # else:
# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
# # _LOGGER.error(msg)
# # raise ValueError(msg)
# # self._agent_config: Union[PPOConfig, A2CConfig]
# # self._current_result: dict
# # self._setup()
# # _LOGGER.debug(
# # f"Created {self.__class__.__name__} using: "
# # f"agent_framework={self._training_config.agent_framework}, "
# # f"agent_identifier="
# # f"{self._training_config.agent_identifier}, "
# # f"deep_learning_framework="
# # f"{self._training_config.deep_learning_framework}"
# # )
# # self._train_agent = None # Required to capture the learning agent to close after eval
# # def _update_session_metadata_file(self) -> None:
# # """
# # Update the ``session_metadata.json`` file.
# # Updates the ``session_metadata.json`` in the ``session_path`` directory
# # with the following key/value pairs:
# # - end_datetime: The date & time the session ended in iso format.
# # - total_episodes: The total number of training episodes completed.
# # - total_time_steps: The total number of training time steps completed.
# # """
# # with open(self.session_path / "session_metadata.json", "r") as file:
# # metadata_dict = json.load(file)
# # metadata_dict["end_datetime"] = datetime.now().isoformat()
# # if not self.is_eval:
# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
# # else:
# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
# # filepath = self.session_path / "session_metadata.json"
# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
# # with open(filepath, "w") as file:
# # json.dump(metadata_dict, file)
# # _LOGGER.debug("Finished updating session metadata file")
# # def _setup(self) -> None:
# # super()._setup()
# # register_env("primaite", _env_creator)
# # self._agent_config = self._agent_config_class()
# # self._agent_config.environment(
# # env="primaite",
# # env_config=dict(
# # training_config_path=self._training_config_path,
# # lay_down_config_path=self._lay_down_config_path,
# # session_path=self.session_path,
# # timestamp_str=self.timestamp_str,
# # ),
# # )
# # self._agent_config.seed = self._training_config.seed
# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
# # self._agent_config.framework(framework="tf")
# # self._agent_config.rollouts(
# # num_rollout_workers=1,
# # num_envs_per_worker=1,
# # horizon=self._training_config.num_train_steps,
# # )
# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
# # def _save_checkpoint(self) -> None:
# # checkpoint_n = self._training_config.checkpoint_every_n_episodes
# # episode_count = self._current_result["episodes_total"]
# # save_checkpoint = False
# # if checkpoint_n:
# # save_checkpoint = episode_count % checkpoint_n == 0
# # if episode_count and save_checkpoint:
# # self._agent.save(str(self.checkpoints_path))
# # def learn(
# # self,
# # **kwargs: Any,
# # ) -> None:
# # """
# # Train the agent.
# # :param kwargs: Any agent-specific key-word args to be passed.
# # """
# # time_steps = self._training_config.num_train_steps
# # episodes = self._training_config.num_train_episodes
# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
# # for i in range(episodes):
# # self._current_result = self._agent.train()
# # self._save_checkpoint()
# # self.save()
# # super().learn()
# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped
# # if self._training_config.session_type is not SessionType.TRAIN:
# # self._train_agent = self._agent
# # else:
# # self._agent.stop()
# # self._plot_av_reward_per_episode(learning_session=True)
# # def _unpack_saved_agent_into_eval(self) -> Path:
# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
# # agent_restore_path = self.evaluation_path / "agent_restore"
# # if agent_restore_path.exists():
# # shutil.rmtree(agent_restore_path)
# # agent_restore_path.mkdir()
# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
# # zip_file.extractall(agent_restore_path)
# # return agent_restore_path
# # def _setup_eval(self):
# # self._can_learn = False
# # self._can_evaluate = True
# # self._agent.restore(str(self._unpack_saved_agent_into_eval()))
# # def evaluate(
# # self,
# # **kwargs,
# # ):
# # """
# # Evaluate the agent.
# # :param kwargs: Any agent-specific key-word args to be passed.
# # """
# # time_steps = self._training_config.num_eval_steps
# # episodes = self._training_config.num_eval_episodes
# # self._setup_eval()
# # self._env: Primaite = Primaite(
# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
# # )
# # self._env.set_as_eval()
# # self.is_eval = True
# # if self._training_config.deterministic:
# # deterministic_str = "deterministic"
# # else:
# # deterministic_str = "non-deterministic"
# # _LOGGER.info(
# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
# # )
# # for episode in range(episodes):
# # obs = self._env.reset()
# # for step in range(time_steps):
# # action = self._agent.compute_single_action(observation=obs, explore=False)
# # obs, rewards, done, info = self._env.step(action)
# # self._env.reset()
# # self._env.close()
# # super().evaluate()
# # # Now we're safe to close the learning agent and write the mean rewards per episode for it
# # if self._training_config.session_type is not SessionType.TRAIN:
# # self._train_agent.stop()
# # self._plot_av_reward_per_episode(learning_session=True)
# # # Perform a clean-up of the unpacked agent
# # if (self.evaluation_path / "agent_restore").exists():
# # shutil.rmtree((self.evaluation_path / "agent_restore"))
# # def _get_latest_checkpoint(self) -> None:
# # raise NotImplementedError
# # @classmethod
# # def load(cls, path: Union[str, Path]) -> RLlibAgent:
# # """Load an agent from file."""
# # raise NotImplementedError
# # def save(self, overwrite_existing: bool = True) -> None:
# # """Save the agent."""
# # # Make temp dir to save in isolation
# # temp_dir = self.learning_path / str(uuid4())
# # temp_dir.mkdir()
# # # Save the agent to the temp dir
# # self._agent.save(str(temp_dir))
# # # Capture the saved RLlib checkpoint inside the temp directory
# # for file in temp_dir.iterdir():
# # checkpoint_dir = file
# # break
# # # Zip the folder
# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# # # Drop the temp directory
# # shutil.rmtree(temp_dir)
# # def export(self) -> None:
# # """Export the agent to transportable file format."""
# # raise NotImplementedError