diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 00f98f5c..a5a5fc77 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -487,7 +487,7 @@ class AclObservation(AbstractObservation): def space(self) -> spaces.Space: return spaces.Dict( { - "RULE": spaces.Dict( + "RULES": spaces.Dict( { i + 1: spaces.Dict( @@ -532,11 +532,11 @@ class NullObservation(AbstractObservation): self.default_observation: Dict = {} def observe(self, state: Dict) -> Dict: - return {} + return 0 @property def space(self) -> spaces.Space: - return spaces.Dict({}) + return spaces.Discrete(1) @classmethod def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation": diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index d978b848..ec6b8e86 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -94,26 +94,27 @@ class PrimaiteGATEClient(GATEClient): @property def observation_space(self) -> spaces.Space: - print("YEEY0") - print(flatten_space(spaces.Dict({}))) - print("YEEY1") - # print(self.parent_session.rl_agent.observation_space.space) return flatten_space(self.parent_session.rl_agent.observation_space.space) def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]: self.parent_session.rl_agent.most_recent_action = action self.parent_session.step() - obs = self.parent_session.rl_agent.observation_space.observe() + state = self.parent_session.simulation.describe_state() + obs = self.parent_session.rl_agent.observation_space.observe(state) obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) - rew = self.parent_session.rl_agent.reward_function.calculate() + rew = self.parent_session.rl_agent.reward_function.calculate(state) term = False trunc = False info = {} return obs, rew, term, trunc, info - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ndarray, Dict]: + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]: self.parent_session.reset() + state = self.parent_session.simulation.describe_state() + obs = self.parent_session.rl_agent.observation_space.observe(state) + obs = flatten(self.parent_session.rl_agent.observation_space.space, obs) + return obs, {} def close(self): self.parent_session.close()