#1593 - Ran pre-commit hook

This commit is contained in:
Chris McCarthy
2023-07-06 14:18:49 +01:00
parent 82d7c168fe
commit c9f4741655
2 changed files with 16 additions and 23 deletions

View File

@@ -259,10 +259,13 @@ class AgentSessionABC(ABC):
@property @property
def _saved_agent_path(self) -> Path: def _saved_agent_path(self) -> Path:
file_name = (f"{self._training_config.agent_framework}_" file_name = (
f"{self._training_config.agent_identifier}_" f"{self._training_config.agent_framework}_"
f"{self.timestamp_str}.zip") f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
return self.learning_path / file_name return self.learning_path / file_name
@abstractmethod @abstractmethod
def save(self): def save(self):
"""Save the agent.""" """Save the agent."""

View File

@@ -85,10 +85,8 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file) metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat() metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result[ metadata_dict["total_episodes"] = self._current_result["episodes_total"]
"episodes_total"] metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
metadata_dict["total_time_steps"] = self._current_result[
"timesteps_total"]
filepath = self.session_path / "session_metadata.json" filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}") _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -111,8 +109,7 @@ class RLlibAgent(AgentSessionABC):
), ),
) )
self._agent_config.training( self._agent_config.training(train_batch_size=self._training_config.num_steps)
train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf") self._agent_config.framework(framework="tf")
self._agent_config.rollouts( self._agent_config.rollouts(
@@ -120,8 +117,7 @@ class RLlibAgent(AgentSessionABC):
num_envs_per_worker=1, num_envs_per_worker=1,
horizon=self._training_config.num_steps, horizon=self._training_config.num_steps,
) )
self._agent: Algorithm = self._agent_config.build( self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self): def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -133,8 +129,8 @@ class RLlibAgent(AgentSessionABC):
self._agent.save(str(self.checkpoints_path)) self._agent.save(str(self.checkpoints_path))
def learn( def learn(
self, self,
**kwargs, **kwargs,
): ):
""" """
Evaluate the agent. Evaluate the agent.
@@ -144,8 +140,7 @@ class RLlibAgent(AgentSessionABC):
time_steps = self._training_config.num_steps time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes episodes = self._training_config.num_episodes
_LOGGER.info( _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes): for i in range(episodes):
self._current_result = self._agent.train() self._current_result = self._agent.train()
self._save_checkpoint() self._save_checkpoint()
@@ -154,8 +149,8 @@ class RLlibAgent(AgentSessionABC):
super().learn() super().learn()
def evaluate( def evaluate(
self, self,
**kwargs, **kwargs,
): ):
""" """
Evaluate the agent. Evaluate the agent.
@@ -187,16 +182,11 @@ class RLlibAgent(AgentSessionABC):
break break
# Zip the folder # Zip the folder
shutil.make_archive( shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
str(self._saved_agent_path).replace(".zip", ""),
"zip",
checkpoint_dir # noqa
)
# Drop the temp directory # Drop the temp directory
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
def export(self): def export(self):
"""Export the agent to transportable file format.""" """Export the agent to transportable file format."""
raise NotImplementedError raise NotImplementedError