Fix GATE Client to work successfully with GATE Server

This commit is contained in:
Marek Wolan
2023-10-11 14:08:55 +01:00
parent def2c3699b
commit b84ab84385
2 changed files with 11 additions and 10 deletions

View File

@@ -487,7 +487,7 @@ class AclObservation(AbstractObservation):
def space(self) -> spaces.Space:
return spaces.Dict(
{
"RULE": spaces.Dict(
"RULES": spaces.Dict(
{
i
+ 1: spaces.Dict(
@@ -532,11 +532,11 @@ class NullObservation(AbstractObservation):
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
return {}
return 0
@property
def space(self) -> spaces.Space:
return spaces.Dict({})
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":

View File

@@ -94,26 +94,27 @@ class PrimaiteGATEClient(GATEClient):
@property
def observation_space(self) -> spaces.Space:
print("YEEY0")
print(flatten_space(spaces.Dict({})))
print("YEEY1")
# print(self.parent_session.rl_agent.observation_space.space)
return flatten_space(self.parent_session.rl_agent.observation_space.space)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
self.parent_session.rl_agent.most_recent_action = action
self.parent_session.step()
obs = self.parent_session.rl_agent.observation_space.observe()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
rew = self.parent_session.rl_agent.reward_function.calculate()
rew = self.parent_session.rl_agent.reward_function.calculate(state)
term = False
trunc = False
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ndarray, Dict]:
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
self.parent_session.reset()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
return obs, {}
def close(self):
self.parent_session.close()