Remove hardcoded checkpoint frequency in rllib
This commit is contained in:
@@ -78,17 +78,18 @@ class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": n_episodes * timesteps_per_episode},
|
||||
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
|
||||
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
|
||||
),
|
||||
param_space=self.config,
|
||||
).fit()
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy paramters from a file."""
|
||||
"""Load policy parameters from a file."""
|
||||
return NotImplemented
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
|
||||
Reference in New Issue
Block a user