#1594 - Added docstrings and fixed training type. Added a clean-up of the unpacked agent in eval dir.

This commit is contained in:
Chris McCarthy
2023-07-21 10:33:22 +01:00
parent df52236a7d
commit 722fe97c84

View File

@@ -75,7 +75,7 @@ class RLlibAgent(AgentSessionABC):
super().__init__(training_config_path, lay_down_config_path)
if self._training_config.session_type == SessionType.EVAL:
msg = "Cannot evaluate an RLlib agent that hasn't been through trainig yet."
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
_LOGGER.critical(msg)
raise RLlibAgentError(msg)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
@@ -194,6 +194,7 @@ class RLlibAgent(AgentSessionABC):
self._plot_av_reward_per_episode(learning_session=True)
def _unpack_saved_agent_into_eval(self) -> Path:
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
agent_restore_path = self.evaluation_path / "agent_restore"
if agent_restore_path.exists():
shutil.rmtree(agent_restore_path)
@@ -248,6 +249,9 @@ class RLlibAgent(AgentSessionABC):
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# Perform a clean-up of the unpacked agent
if (self.evaluation_path / "agent_restore").exists():
shutil.rmtree((self.evaluation_path / "agent_restore"))
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError