#1386: fix saving of agent

This commit is contained in:
Czar Echavez
2023-07-05 11:41:18 +01:00
parent 9001510fe7
commit 075b11aeca
3 changed files with 12 additions and 11 deletions

View File

@@ -257,17 +257,14 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg)
pass
def save(self, path_str: str | Path):
def save(self):
"""Save the agent."""
if path_str:
self._agent.save(path_str)
return
# if no path, save to root but with a random UUID
self._agent.save(
agent_path = (
self.session_path
/ f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{uuid4()}"
/ f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{self.timestamp_str}"
)
_LOGGER.debug(f"Saving agent: {agent_path}")
self._agent.save(agent_path)
@abstractmethod
def export(self):

View File

@@ -141,8 +141,11 @@ class RLlibAgent(AgentSessionABC):
self._current_result = self._agent.train()
self._save_checkpoint()
self._agent.stop()
super().learn()
self.save(self.learning_path / f"rllib_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
# save agent
self.save()
def evaluate(
self,

View File

@@ -89,11 +89,13 @@ class SB3Agent(AgentSessionABC):
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self.save(self.learning_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
self._env.reset()
self._env.close()
super().learn()
# save agent
self.save()
def evaluate(
self,
deterministic: bool = True,
@@ -124,7 +126,6 @@ class SB3Agent(AgentSessionABC):
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self.save(self.evaluation_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
self._env.reset()
self._env.close()
super().evaluate()