#1593 - Ran pre-commit hook
This commit is contained in:
@@ -259,10 +259,13 @@ class AgentSessionABC(ABC):
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip")
|
||||
file_name = (
|
||||
f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip"
|
||||
)
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
|
||||
@@ -85,10 +85,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result[
|
||||
"episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result[
|
||||
"timesteps_total"]
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -111,8 +109,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(
|
||||
train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
@@ -120,8 +117,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(
|
||||
logger_creator=_custom_log_creator(self.learning_path))
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
@@ -133,8 +129,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
@@ -144,8 +140,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
@@ -154,8 +149,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
@@ -187,16 +182,11 @@ class RLlibAgent(AgentSessionABC):
|
||||
break
|
||||
|
||||
# Zip the folder
|
||||
shutil.make_archive(
|
||||
str(self._saved_agent_path).replace(".zip", ""),
|
||||
"zip",
|
||||
checkpoint_dir # noqa
|
||||
)
|
||||
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user