diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 37735812..4b39839a 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -257,17 +257,14 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass - def save(self, path_str: str | Path): + def save(self): """Save the agent.""" - if path_str: - self._agent.save(path_str) - return - - # if no path, save to root but with a random UUID - self._agent.save( + agent_path = ( self.session_path - / f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{uuid4()}" + / f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{self.timestamp_str}" ) + _LOGGER.debug(f"Saving agent: {agent_path}") + self._agent.save(agent_path) @abstractmethod def export(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index b058ffc7..30edd93c 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -141,8 +141,11 @@ class RLlibAgent(AgentSessionABC): self._current_result = self._agent.train() self._save_checkpoint() self._agent.stop() + super().learn() - self.save(self.learning_path / f"rllib_{self._training_config.agent_identifier}_{self.timestamp_str}.zip") + + # save agent + self.save() def evaluate( self, diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index e0c1ee79..17fbe0a6 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -89,11 +89,13 @@ class SB3Agent(AgentSessionABC): for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() - self.save(self.learning_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip") self._env.reset() self._env.close() super().learn() + # save agent + self.save() + def evaluate( self, deterministic: bool = True, @@ -124,7 +126,6 @@ class SB3Agent(AgentSessionABC): if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) - self.save(self.evaluation_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip") self._env.reset() self._env.close() super().evaluate()