diff --git a/src/primaite/session/policy/rllib.py b/src/primaite/session/policy/rllib.py index 7ba3edd0..be181797 100644 --- a/src/primaite/session/policy/rllib.py +++ b/src/primaite/session/policy/rllib.py @@ -78,17 +78,18 @@ class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" + checkpoint_freq = self.session.io_manager.settings.checkpoint_interval tune.Tuner( "PPO", run_config=air.RunConfig( stop={"training_iteration": n_episodes * timesteps_per_episode}, - checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10), + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq), ), param_space=self.config, ).fit() def load(self, model_path: Path) -> None: - """Load policy paramters from a file.""" + """Load policy parameters from a file.""" return NotImplemented def eval(self, n_episodes: int, deterministic: bool) -> None: