Fix GATE Client to work successfully with GATE Server
This commit is contained in:
@@ -487,7 +487,7 @@ class AclObservation(AbstractObservation):
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict(
|
||||
{
|
||||
"RULE": spaces.Dict(
|
||||
"RULES": spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
@@ -532,11 +532,11 @@ class NullObservation(AbstractObservation):
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
return {}
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({})
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
|
||||
|
||||
@@ -94,26 +94,27 @@ class PrimaiteGATEClient(GATEClient):
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
print("YEEY0")
|
||||
print(flatten_space(spaces.Dict({})))
|
||||
print("YEEY1")
|
||||
# print(self.parent_session.rl_agent.observation_space.space)
|
||||
return flatten_space(self.parent_session.rl_agent.observation_space.space)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
|
||||
self.parent_session.rl_agent.most_recent_action = action
|
||||
self.parent_session.step()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
rew = self.parent_session.rl_agent.reward_function.calculate()
|
||||
rew = self.parent_session.rl_agent.reward_function.calculate(state)
|
||||
term = False
|
||||
trunc = False
|
||||
info = {}
|
||||
return obs, rew, term, trunc, info
|
||||
|
||||
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ndarray, Dict]:
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
|
||||
self.parent_session.reset()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
return obs, {}
|
||||
|
||||
def close(self):
|
||||
self.parent_session.close()
|
||||
|
||||
Reference in New Issue
Block a user