#1386: fix bug with agent zip file not being saved after run
This commit is contained in:
@@ -257,10 +257,17 @@ class AgentSessionABC(ABC):
|
|||||||
raise FileNotFoundError(msg)
|
raise FileNotFoundError(msg)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
def save(self, path_str: str | Path):
|
||||||
def save(self):
|
|
||||||
"""Save the agent."""
|
"""Save the agent."""
|
||||||
self._agent.save(self.session_path)
|
if path_str:
|
||||||
|
self._agent.save(path_str)
|
||||||
|
return
|
||||||
|
|
||||||
|
# if no path, save to root but with a random UUID
|
||||||
|
self._agent.save(
|
||||||
|
self.session_path
|
||||||
|
/ f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{uuid4()}"
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def export(self):
|
def export(self):
|
||||||
|
|||||||
@@ -142,6 +142,7 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
self._save_checkpoint()
|
self._save_checkpoint()
|
||||||
self._agent.stop()
|
self._agent.stop()
|
||||||
super().learn()
|
super().learn()
|
||||||
|
self.save(self.learning_path / f"rllib_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
@@ -162,10 +163,6 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
"""Load an agent from file."""
|
"""Load an agent from file."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save(self):
|
|
||||||
"""Save the agent."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def export(self):
|
def export(self):
|
||||||
"""Export the agent to transportable file format."""
|
"""Export the agent to transportable file format."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ class SB3Agent(AgentSessionABC):
|
|||||||
for i in range(episodes):
|
for i in range(episodes):
|
||||||
self._agent.learn(total_timesteps=time_steps)
|
self._agent.learn(total_timesteps=time_steps)
|
||||||
self._save_checkpoint()
|
self._save_checkpoint()
|
||||||
|
self.save(self.learning_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
|
||||||
self._env.reset()
|
self._env.reset()
|
||||||
self._env.close()
|
self._env.close()
|
||||||
super().learn()
|
super().learn()
|
||||||
@@ -123,6 +124,7 @@ class SB3Agent(AgentSessionABC):
|
|||||||
if isinstance(action, np.ndarray):
|
if isinstance(action, np.ndarray):
|
||||||
action = np.int64(action)
|
action = np.int64(action)
|
||||||
obs, rewards, done, info = self._env.step(action)
|
obs, rewards, done, info = self._env.step(action)
|
||||||
|
self.save(self.evaluation_path / f"sb3_{self._training_config.agent_identifier}_{self.timestamp_str}.zip")
|
||||||
self._env.reset()
|
self._env.reset()
|
||||||
self._env.close()
|
self._env.close()
|
||||||
super().evaluate()
|
super().evaluate()
|
||||||
@@ -132,10 +134,6 @@ class SB3Agent(AgentSessionABC):
|
|||||||
"""Load an agent from file."""
|
"""Load an agent from file."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save(self):
|
|
||||||
"""Save the agent."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def export(self):
|
def export(self):
|
||||||
"""Export the agent to transportable file format."""
|
"""Export the agent to transportable file format."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user