#1593 - Ran pre-commit hook
This commit is contained in:
@@ -259,10 +259,13 @@ class AgentSessionABC(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _saved_agent_path(self) -> Path:
|
def _saved_agent_path(self) -> Path:
|
||||||
file_name = (f"{self._training_config.agent_framework}_"
|
file_name = (
|
||||||
f"{self._training_config.agent_identifier}_"
|
f"{self._training_config.agent_framework}_"
|
||||||
f"{self.timestamp_str}.zip")
|
f"{self._training_config.agent_identifier}_"
|
||||||
|
f"{self.timestamp_str}.zip"
|
||||||
|
)
|
||||||
return self.learning_path / file_name
|
return self.learning_path / file_name
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Save the agent."""
|
"""Save the agent."""
|
||||||
|
|||||||
@@ -85,10 +85,8 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
metadata_dict = json.load(file)
|
metadata_dict = json.load(file)
|
||||||
|
|
||||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||||
metadata_dict["total_episodes"] = self._current_result[
|
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||||
"episodes_total"]
|
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||||
metadata_dict["total_time_steps"] = self._current_result[
|
|
||||||
"timesteps_total"]
|
|
||||||
|
|
||||||
filepath = self.session_path / "session_metadata.json"
|
filepath = self.session_path / "session_metadata.json"
|
||||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||||
@@ -111,8 +109,7 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._agent_config.training(
|
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||||
train_batch_size=self._training_config.num_steps)
|
|
||||||
self._agent_config.framework(framework="tf")
|
self._agent_config.framework(framework="tf")
|
||||||
|
|
||||||
self._agent_config.rollouts(
|
self._agent_config.rollouts(
|
||||||
@@ -120,8 +117,7 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
num_envs_per_worker=1,
|
num_envs_per_worker=1,
|
||||||
horizon=self._training_config.num_steps,
|
horizon=self._training_config.num_steps,
|
||||||
)
|
)
|
||||||
self._agent: Algorithm = self._agent_config.build(
|
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||||
logger_creator=_custom_log_creator(self.learning_path))
|
|
||||||
|
|
||||||
def _save_checkpoint(self):
|
def _save_checkpoint(self):
|
||||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||||
@@ -133,8 +129,8 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
self._agent.save(str(self.checkpoints_path))
|
self._agent.save(str(self.checkpoints_path))
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate the agent.
|
Evaluate the agent.
|
||||||
@@ -144,8 +140,7 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
time_steps = self._training_config.num_steps
|
time_steps = self._training_config.num_steps
|
||||||
episodes = self._training_config.num_episodes
|
episodes = self._training_config.num_episodes
|
||||||
|
|
||||||
_LOGGER.info(
|
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||||
f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
|
||||||
for i in range(episodes):
|
for i in range(episodes):
|
||||||
self._current_result = self._agent.train()
|
self._current_result = self._agent.train()
|
||||||
self._save_checkpoint()
|
self._save_checkpoint()
|
||||||
@@ -154,8 +149,8 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
super().learn()
|
super().learn()
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate the agent.
|
Evaluate the agent.
|
||||||
@@ -187,16 +182,11 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Zip the folder
|
# Zip the folder
|
||||||
shutil.make_archive(
|
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||||
str(self._saved_agent_path).replace(".zip", ""),
|
|
||||||
"zip",
|
|
||||||
checkpoint_dir # noqa
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drop the temp directory
|
# Drop the temp directory
|
||||||
shutil.rmtree(temp_dir)
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
def export(self):
|
def export(self):
|
||||||
"""Export the agent to transportable file format."""
|
"""Export the agent to transportable file format."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user