#1386: fix saving of agent
This commit is contained in:
@@ -257,17 +257,14 @@ class AgentSessionABC(ABC):
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
def save(self, path_str: str | Path):
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
if path_str:
|
||||
self._agent.save(path_str)
|
||||
return
|
||||
|
||||
# if no path, save to root but with a random UUID
|
||||
self._agent.save(
|
||||
agent_path = (
|
||||
self.session_path
|
||||
/ f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{uuid4()}"
|
||||
/ f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{self.timestamp_str}"
|
||||
)
|
||||
_LOGGER.debug(f"Saving agent: {agent_path}")
|
||||
self._agent.save(agent_path)
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
|
||||
@@ -141,8 +141,11 @@ class RLlibAgent(AgentSessionABC):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self._agent.stop()
|
||||
|
||||
super().learn()
|
||||
self.save(self.learning_path / f"rllib_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
|
||||
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
|
||||
@@ -89,11 +89,13 @@ class SB3Agent(AgentSessionABC):
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self.save(self.learning_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
deterministic: bool = True,
|
||||
@@ -124,7 +126,6 @@ class SB3Agent(AgentSessionABC):
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
self.save(self.evaluation_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
Reference in New Issue
Block a user