diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py
index fecf84d0..32118597 100644
--- a/src/primaite/agents/agent.py
+++ b/src/primaite/agents/agent.py
@@ -259,10 +259,13 @@ class AgentSessionABC(ABC):
 
     @property
     def _saved_agent_path(self) -> Path:
-        file_name = (f"{self._training_config.agent_framework}_"
-                     f"{self._training_config.agent_identifier}_"
-                     f"{self.timestamp_str}.zip")
+        file_name = (
+            f"{self._training_config.agent_framework}_"
+            f"{self._training_config.agent_identifier}_"
+            f"{self.timestamp_str}.zip"
+        )
         return self.learning_path / file_name
+
     @abstractmethod
     def save(self):
         """Save the agent."""
diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py
index ce072a03..427072c4 100644
--- a/src/primaite/agents/rllib.py
+++ b/src/primaite/agents/rllib.py
@@ -85,10 +85,8 @@ class RLlibAgent(AgentSessionABC):
             metadata_dict = json.load(file)
 
         metadata_dict["end_datetime"] = datetime.now().isoformat()
-        metadata_dict["total_episodes"] = self._current_result[
-            "episodes_total"]
-        metadata_dict["total_time_steps"] = self._current_result[
-            "timesteps_total"]
+        metadata_dict["total_episodes"] = self._current_result["episodes_total"]
+        metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
 
         filepath = self.session_path / "session_metadata.json"
         _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -111,8 +109,7 @@ class RLlibAgent(AgentSessionABC):
             ),
         )
 
-        self._agent_config.training(
-            train_batch_size=self._training_config.num_steps)
+        self._agent_config.training(train_batch_size=self._training_config.num_steps)
 
         self._agent_config.framework(framework="tf")
         self._agent_config.rollouts(
@@ -120,8 +117,7 @@
             num_envs_per_worker=1,
             horizon=self._training_config.num_steps,
         )
-        self._agent: Algorithm = self._agent_config.build(
-            logger_creator=_custom_log_creator(self.learning_path))
+        self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
 
     def _save_checkpoint(self):
         checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -133,8 +129,8 @@ class RLlibAgent(AgentSessionABC):
             self._agent.save(str(self.checkpoints_path))
 
     def learn(
-            self,
-            **kwargs,
+        self,
+        **kwargs,
     ):
         """
         Evaluate the agent.
@@ -144,8 +140,7 @@ class RLlibAgent(AgentSessionABC):
         time_steps = self._training_config.num_steps
         episodes = self._training_config.num_episodes
 
-        _LOGGER.info(
-            f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
+        _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
         for i in range(episodes):
             self._current_result = self._agent.train()
             self._save_checkpoint()
@@ -154,8 +149,8 @@ class RLlibAgent(AgentSessionABC):
         super().learn()
 
     def evaluate(
-            self,
-            **kwargs,
+        self,
+        **kwargs,
    ):
         """
         Evaluate the agent.
@@ -187,16 +182,11 @@ class RLlibAgent(AgentSessionABC):
                 break
 
         # Zip the folder
-        shutil.make_archive(
-            str(self._saved_agent_path).replace(".zip", ""),
-            "zip",
-            checkpoint_dir  # noqa
-        )
+        shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir)  # noqa
 
         # Drop the temp directory
         shutil.rmtree(temp_dir)
-
     def export(self):
         """Export the agent to transportable file format."""
         raise NotImplementedError