Remove hardcoded checkpoint frequency in rllib

This commit is contained in:
Marek Wolan
2023-11-24 09:37:26 +00:00
parent 6754dbf541
commit abba1ef86b

View File

@@ -78,17 +78,18 @@ class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
    """Train the agent.

    Builds a Ray Tune ``Tuner`` for the ``"PPO"`` algorithm with ``self.config``
    as the parameter space and runs it with ``fit()``.

    :param n_episodes: Multiplied with ``timesteps_per_episode`` to form the
        ``training_iteration`` stop threshold.
    :param timesteps_per_episode: Multiplied with ``n_episodes`` to form the
        ``training_iteration`` stop threshold.
    """
    # The checkpoint cadence comes from the session's IO settings instead of a
    # hard-coded value, so it can be configured per session.
    checkpoint_freq = self.session.io_manager.settings.checkpoint_interval

    # NOTE(review): the stop criterion is keyed on "training_iteration" while
    # the threshold is episodes * timesteps - confirm the units are intended.
    run_config = air.RunConfig(
        stop={"training_iteration": n_episodes * timesteps_per_episode},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
    )
    tune.Tuner(
        "PPO",
        run_config=run_config,
        param_space=self.config,
    ).fit()
def load(self, model_path: Path) -> None:
    """Load policy parameters from a file.

    Loading is not implemented: ``model_path`` is ignored and the
    ``NotImplemented`` sentinel is returned rather than an exception raised.

    :param model_path: Path of the file to load from (currently unused).
    """
    return NotImplemented
def eval(self, n_episodes: int, deterministic: bool) -> None: