diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 8657fc45..01df33de 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -45,6 +45,7 @@ class AgentSettings(BaseModel): start_settings: Optional[AgentStartSettings] = None "Configuration for when an agent begins performing it's actions" flatten_obs: bool = True + "Whether to flatten the observation space before passing it to the agent. True by default." @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": @@ -176,7 +177,7 @@ class ProxyAgent(AbstractAgent): reward_function=reward_function, ) self.most_recent_action: ActType - self.flatten_obs: bool = agent_settings.flatten_obs + self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 36ab3f58..6701f183 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -23,7 +23,6 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.game: "PrimaiteGame" = game self.agent: ProxyAgent = self.game.rl_agents[0] - self.flatten_obs: bool = False def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment."""