Merge remote-tracking branch 'devops/bugfix/episode-length-and-rewards' into feature/2085-dump_describe_state
This commit is contained in:
@@ -83,14 +83,15 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game=env_config["game"])
|
||||
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
|
||||
self.env.game.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
@@ -106,14 +107,14 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Optional[Dict] = None) -> None:
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.game: PrimaiteGame = env_config["game"]
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
|
||||
"""Reference to the primaite game"""
|
||||
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
|
||||
"""List of all possible agents in the environment. This list should not change!"""
|
||||
@@ -122,7 +123,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{name: agent.observation_manager.space for name, agent in self.agents.items()}
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
@@ -173,4 +177,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}
|
||||
obs = {}
|
||||
for name, agent in self.agents.items():
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
Reference in New Issue
Block a user