diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 685fe776..37735812 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -257,10 +257,17 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass - @abstractmethod - def save(self): + def save(self, path_str: str | Path): """Save the agent.""" - self._agent.save(self.session_path) + if path_str: + self._agent.save(path_str) + return + + # if no path, save to root but with a random UUID + self._agent.save( + self.session_path + / f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{uuid4()}" + ) @abstractmethod def export(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d851ba9c..b058ffc7 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -142,6 +142,7 @@ class RLlibAgent(AgentSessionABC): self._save_checkpoint() self._agent.stop() super().learn() + self.save(self.learning_path / f"rllib_{self._training_config.agent_identifier}_{self.timestamp_str}.zip") def evaluate( self, @@ -162,10 +163,6 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): - """Save the agent.""" - raise NotImplementedError - def export(self): """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index f5ac44cb..e0c1ee79 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -89,6 +89,7 @@ class SB3Agent(AgentSessionABC): for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() + self.save(self.learning_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip") self._env.reset() self._env.close() super().learn() @@ -123,6 +124,7 @@ class SB3Agent(AgentSessionABC): if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + self.save(self.evaluation_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip") self._env.reset() self._env.close() super().evaluate() @@ -132,10 +134,6 @@ class SB3Agent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): - """Save the agent.""" - raise NotImplementedError - def export(self): """Export the agent to transportable file format.""" raise NotImplementedError