diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 4149d02d..fb062f54 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -75,7 +75,7 @@ class RLlibAgent(AgentSessionABC): super().__init__(training_config_path, lay_down_config_path) if self._training_config.session_type == SessionType.EVAL: - msg = "Cannot evaluate an RLlib agent that hasn't been through trainig yet." + msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." _LOGGER.critical(msg) raise RLlibAgentError(msg) if not self._training_config.agent_framework == AgentFramework.RLLIB: @@ -194,6 +194,7 @@ class RLlibAgent(AgentSessionABC): self._plot_av_reward_per_episode(learning_session=True) def _unpack_saved_agent_into_eval(self) -> Path: + """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" agent_restore_path = self.evaluation_path / "agent_restore" if agent_restore_path.exists(): shutil.rmtree(agent_restore_path) @@ -248,6 +249,9 @@ class RLlibAgent(AgentSessionABC): if self._training_config.session_type is not SessionType.TRAIN: self._train_agent.stop() self._plot_av_reward_per_episode(learning_session=True) + # Perform a clean-up of the unpacked agent + if (self.evaluation_path / "agent_restore").exists(): + shutil.rmtree((self.evaluation_path / "agent_restore")) def _get_latest_checkpoint(self) -> None: raise NotImplementedError