#1386: fix bug with agent zip file not being saved after run

This commit is contained in:
Czar Echavez
2023-07-04 16:30:31 +01:00
parent 410afc1d40
commit 9001510fe7
3 changed files with 13 additions and 11 deletions

View File

@@ -257,10 +257,17 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg) raise FileNotFoundError(msg)
pass pass
@abstractmethod def save(self, path_str: str | Path):
def save(self):
"""Save the agent.""" """Save the agent."""
self._agent.save(self.session_path) if path_str:
self._agent.save(path_str)
return
# if no path, save to root but with a random UUID
self._agent.save(
self.session_path
/ f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{uuid4()}"
)
@abstractmethod @abstractmethod
def export(self): def export(self):

View File

@@ -142,6 +142,7 @@ class RLlibAgent(AgentSessionABC):
self._save_checkpoint() self._save_checkpoint()
self._agent.stop() self._agent.stop()
super().learn() super().learn()
self.save(self.learning_path / f"rllib_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
def evaluate( def evaluate(
self, self,
@@ -162,10 +163,6 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file.""" """Load an agent from file."""
raise NotImplementedError raise NotImplementedError
def save(self):
"""Save the agent."""
raise NotImplementedError
def export(self): def export(self):
"""Export the agent to transportable file format.""" """Export the agent to transportable file format."""
raise NotImplementedError raise NotImplementedError

View File

@@ -89,6 +89,7 @@ class SB3Agent(AgentSessionABC):
for i in range(episodes): for i in range(episodes):
self._agent.learn(total_timesteps=time_steps) self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint() self._save_checkpoint()
self.save(self.learning_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
self._env.reset() self._env.reset()
self._env.close() self._env.close()
super().learn() super().learn()
@@ -123,6 +124,7 @@ class SB3Agent(AgentSessionABC):
if isinstance(action, np.ndarray): if isinstance(action, np.ndarray):
action = np.int64(action) action = np.int64(action)
obs, rewards, done, info = self._env.step(action) obs, rewards, done, info = self._env.step(action)
self.save(self.evaluation_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
self._env.reset() self._env.reset()
self._env.close() self._env.close()
super().evaluate() super().evaluate()
@@ -132,10 +134,6 @@ class SB3Agent(AgentSessionABC):
"""Load an agent from file.""" """Load an agent from file."""
raise NotImplementedError raise NotImplementedError
def save(self):
"""Save the agent."""
raise NotImplementedError
def export(self): def export(self):
"""Export the agent to transportable file format.""" """Export the agent to transportable file format."""
raise NotImplementedError raise NotImplementedError