#1594 - Added docstrings and fixed training type. Added a clean-up of the unpacked agent in eval dir.
This commit is contained in:
@@ -75,7 +75,7 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
|
|
||||||
super().__init__(training_config_path, lay_down_config_path)
|
super().__init__(training_config_path, lay_down_config_path)
|
||||||
if self._training_config.session_type == SessionType.EVAL:
|
if self._training_config.session_type == SessionType.EVAL:
|
||||||
msg = "Cannot evaluate an RLlib agent that hasn't been through trainig yet."
|
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
|
||||||
_LOGGER.critical(msg)
|
_LOGGER.critical(msg)
|
||||||
raise RLlibAgentError(msg)
|
raise RLlibAgentError(msg)
|
||||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||||
@@ -194,6 +194,7 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
self._plot_av_reward_per_episode(learning_session=True)
|
self._plot_av_reward_per_episode(learning_session=True)
|
||||||
|
|
||||||
def _unpack_saved_agent_into_eval(self) -> Path:
|
def _unpack_saved_agent_into_eval(self) -> Path:
|
||||||
|
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
|
||||||
agent_restore_path = self.evaluation_path / "agent_restore"
|
agent_restore_path = self.evaluation_path / "agent_restore"
|
||||||
if agent_restore_path.exists():
|
if agent_restore_path.exists():
|
||||||
shutil.rmtree(agent_restore_path)
|
shutil.rmtree(agent_restore_path)
|
||||||
@@ -248,6 +249,9 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
if self._training_config.session_type is not SessionType.TRAIN:
|
if self._training_config.session_type is not SessionType.TRAIN:
|
||||||
self._train_agent.stop()
|
self._train_agent.stop()
|
||||||
self._plot_av_reward_per_episode(learning_session=True)
|
self._plot_av_reward_per_episode(learning_session=True)
|
||||||
|
# Perform a clean-up of the unpacked agent
|
||||||
|
if (self.evaluation_path / "agent_restore").exists():
|
||||||
|
shutil.rmtree((self.evaluation_path / "agent_restore"))
|
||||||
|
|
||||||
def _get_latest_checkpoint(self) -> None:
|
def _get_latest_checkpoint(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user