207 lines
7.8 KiB
Python
207 lines
7.8 KiB
Python
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from logging import Logger
|
|
from pathlib import Path
|
|
from typing import Any, Optional, Union
|
|
|
|
import numpy as np
|
|
from stable_baselines3 import A2C, PPO
|
|
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
|
|
|
from primaite import getLogger
|
|
from primaite.agents.agent_abc import AgentSessionABC
|
|
from primaite.common.enums import AgentFramework, AgentIdentifier
|
|
from primaite.environment.primaite_env import Primaite
|
|
|
|
_LOGGER: Logger = getLogger(__name__)
|
|
|
|
|
|
class SB3Agent(AgentSessionABC):
|
|
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
|
|
|
def __init__(
|
|
self,
|
|
training_config_path: Optional[Union[str, Path]] = None,
|
|
lay_down_config_path: Optional[Union[str, Path]] = None,
|
|
session_path: Optional[Union[str, Path]] = None,
|
|
legacy_training_config: bool = False,
|
|
legacy_lay_down_config: bool = False,
|
|
) -> None:
|
|
"""
|
|
Initialise the SB3 Agent training session.
|
|
|
|
:param training_config_path: YAML file containing configurable items defined in
|
|
`primaite.config.training_config.TrainingConfig`
|
|
:type training_config_path: Union[path, str]
|
|
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
|
:type lay_down_config_path: Union[path, str]
|
|
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
|
otherwise False.
|
|
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
|
otherwise False.
|
|
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
|
|
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
|
or `A2C`)
|
|
"""
|
|
super().__init__(
|
|
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
|
)
|
|
if not self._training_config.agent_framework == AgentFramework.SB3:
|
|
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
|
_LOGGER.error(msg)
|
|
raise ValueError(msg)
|
|
self._agent_class: Union[PPO, A2C]
|
|
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
|
self._agent_class = PPO
|
|
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
|
self._agent_class = A2C
|
|
else:
|
|
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
|
|
_LOGGER.error(msg)
|
|
raise ValueError(msg)
|
|
|
|
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
|
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
_LOGGER.debug(
|
|
f"Created {self.__class__.__name__} using: "
|
|
f"agent_framework={self._training_config.agent_framework}, "
|
|
f"agent_identifier="
|
|
f"{self._training_config.agent_identifier}"
|
|
)
|
|
|
|
self.is_eval = False
|
|
|
|
self._setup()
|
|
|
|
def _setup(self) -> None:
|
|
"""Set up the SB3 Agent."""
|
|
self._env = Primaite(
|
|
training_config_path=self._training_config_path,
|
|
lay_down_config_path=self._lay_down_config_path,
|
|
session_path=self.session_path,
|
|
timestamp_str=self.timestamp_str,
|
|
legacy_training_config=self.legacy_training_config,
|
|
legacy_lay_down_config=self.legacy_lay_down_config,
|
|
)
|
|
|
|
# check if there is a zip file that needs to be loaded
|
|
load_file = next(self.session_path.rglob("*.zip"), None)
|
|
|
|
if not load_file:
|
|
# create a new env and agent
|
|
|
|
self._agent = self._agent_class(
|
|
PPOMlp,
|
|
self._env,
|
|
verbose=self.sb3_output_verbose_level,
|
|
n_steps=self._training_config.num_train_steps,
|
|
tensorboard_log=str(self._tensorboard_log_path),
|
|
seed=self._training_config.seed,
|
|
)
|
|
else:
|
|
# set env values from session metadata
|
|
with open(self.session_path / "session_metadata.json", "r") as file:
|
|
md_dict = json.load(file)
|
|
|
|
# load environment values
|
|
if self.is_eval:
|
|
# evaluation always starts at 0
|
|
self._env.episode_count = 0
|
|
self._env.total_step_count = 0
|
|
else:
|
|
# carry on from previous learning sessions
|
|
self._env.episode_count = md_dict["learning"]["total_episodes"]
|
|
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
|
|
|
|
# load the file
|
|
self._agent = self._agent_class.load(load_file, env=self._env)
|
|
|
|
# set agent values
|
|
self._agent.verbose = self.sb3_output_verbose_level
|
|
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
|
|
|
|
super()._setup()
|
|
|
|
def _save_checkpoint(self) -> None:
|
|
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
|
episode_count = self._env.episode_count
|
|
save_checkpoint = False
|
|
if checkpoint_n:
|
|
save_checkpoint = episode_count % checkpoint_n == 0
|
|
if episode_count and save_checkpoint:
|
|
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
|
self._agent.save(checkpoint_path)
|
|
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
|
|
|
def _get_latest_checkpoint(self) -> None:
|
|
pass
|
|
|
|
def learn(
|
|
self,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""
|
|
Train the agent.
|
|
|
|
:param kwargs: Any agent-specific key-word args to be passed.
|
|
"""
|
|
time_steps = self._training_config.num_train_steps
|
|
episodes = self._training_config.num_train_episodes
|
|
self.is_eval = False
|
|
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
|
for i in range(episodes):
|
|
self._agent.learn(total_timesteps=time_steps)
|
|
self._save_checkpoint()
|
|
self._env._write_av_reward_per_episode() # noqa
|
|
self.save()
|
|
self._env.close()
|
|
super().learn()
|
|
|
|
# save agent
|
|
self.save()
|
|
|
|
self._plot_av_reward_per_episode(learning_session=True)
|
|
|
|
def evaluate(
|
|
self,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""
|
|
Evaluate the agent.
|
|
|
|
:param kwargs: Any agent-specific key-word args to be passed.
|
|
"""
|
|
time_steps = self._training_config.num_eval_steps
|
|
episodes = self._training_config.num_eval_episodes
|
|
self._env.set_as_eval()
|
|
self.is_eval = True
|
|
if self._training_config.deterministic:
|
|
deterministic_str = "deterministic"
|
|
else:
|
|
deterministic_str = "non-deterministic"
|
|
_LOGGER.info(
|
|
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
|
)
|
|
for episode in range(episodes):
|
|
obs = self._env.reset()
|
|
|
|
for step in range(time_steps):
|
|
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
|
|
if isinstance(action, np.ndarray):
|
|
action = np.int64(action)
|
|
obs, rewards, done, info = self._env.step(action)
|
|
self._env._write_av_reward_per_episode() # noqa
|
|
self._env.close()
|
|
super().evaluate()
|
|
|
|
def save(self) -> None:
|
|
"""Save the agent."""
|
|
self._agent.save(self._saved_agent_path)
|
|
|
|
def export(self) -> None:
|
|
"""Export the agent to transportable file format."""
|
|
raise NotImplementedError
|