#1593 - Ran pre-commit hook

This commit is contained in:
Chris McCarthy
2023-07-06 14:18:49 +01:00
parent fc98441a11
commit ddabf991ce
2 changed files with 16 additions and 23 deletions

View File

@@ -259,10 +259,13 @@ class AgentSessionABC(ABC):
@property
def _saved_agent_path(self) -> Path:
file_name = (f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip")
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
return self.learning_path / file_name
@abstractmethod
def save(self):
"""Save the agent."""

View File

@@ -85,10 +85,8 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result[
"episodes_total"]
metadata_dict["total_time_steps"] = self._current_result[
"timesteps_total"]
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -111,8 +109,7 @@ class RLlibAgent(AgentSessionABC):
),
)
self._agent_config.training(
train_batch_size=self._training_config.num_steps)
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
@@ -120,8 +117,7 @@ class RLlibAgent(AgentSessionABC):
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
)
self._agent: Algorithm = self._agent_config.build(
logger_creator=_custom_log_creator(self.learning_path))
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -133,8 +129,8 @@ class RLlibAgent(AgentSessionABC):
self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs,
self,
**kwargs,
):
"""
Evaluate the agent.
@@ -144,8 +140,7 @@ class RLlibAgent(AgentSessionABC):
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
_LOGGER.info(
f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
@@ -154,8 +149,8 @@ class RLlibAgent(AgentSessionABC):
super().learn()
def evaluate(
self,
**kwargs,
self,
**kwargs,
):
"""
Evaluate the agent.
@@ -187,16 +182,11 @@ class RLlibAgent(AgentSessionABC):
break
# Zip the folder
shutil.make_archive(
str(self._saved_agent_path).replace(".zip", ""),
"zip",
checkpoint_dir # noqa
)
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""
raise NotImplementedError