diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d08f60cb..0281de7e 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -66,6 +66,7 @@ class RLlibAgent(AgentSessionABC): msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) raise ValueError(msg) + self._agent_config_class: Union[PPOConfig, A2CConfig] if self._training_config.agent_identifier == AgentIdentifier.PPO: self._agent_config_class = PPOConfig elif self._training_config.agent_identifier == AgentIdentifier.A2C: