import copy import json from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv from primaite import getLogger from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT _LOGGER = getLogger(__name__) class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some assumptions about the agent list always having a list of length 1. """ def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() self.game_config: Dict = game_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" self.episode_counter: int = 0 """Current episode number.""" self.io = PrimaiteIO.from_config(game_config.get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" @property def agent(self) -> ProxyAgent: """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" return self.game.rl_agents[self._agent_name] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() # this doesn't update observation, just gets the current observation reward = self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() info = { "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), "state": state, } with open(path, "w") as file: json.dump(data, file) def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" _LOGGER.info( f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state=state) next_obs = self._get_obs() info = {} return next_obs, info @property def action_space(self) -> gymnasium.Space: """Return the action space of the environment.""" return self.agent.action_manager.space @property def observation_space(self) -> gymnasium.Space: """Return the observation space of the environment.""" if self.agent.flatten_obs: return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) else: return self.agent.observation_manager.space def _get_obs(self) -> ObsType: """Return the current observation.""" if self.agent.flatten_obs: unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) else: return self.agent.observation_manager.current_observation class PrimaiteRayEnv(gymnasium.Env): """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" def __init__(self, env_config: Dict) -> None: """Initialise the environment. :param env_config: A dictionary containing the environment configuration. :type env_config: Dict """ self.env = PrimaiteGymEnv(game_config=env_config) self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" return self.env.reset(seed=seed) def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" return self.env.step(action) class PrimaiteRayMARLEnv(MultiAgentEnv): """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" def __init__(self, env_config: Dict) -> None: """Initialise the environment. :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. :type env_config: Dict """ self.game_config: Dict = env_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Reference to the primaite game""" self._agent_ids = list(self.game.rl_agents.keys()) """Agent ids. This is a list of strings of agent names.""" self.episode_counter: int = 0 """Current episode number.""" self.terminateds = set() self.truncateds = set() self.observation_space = gymnasium.spaces.Dict( { name: gymnasium.spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items() } ) self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} ) self.io = PrimaiteIO.from_config(env_config.get("io_settings")) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" super().__init__() @property def agents(self) -> Dict[str, ProxyAgent]: """Grab a fresh reference to the agents from this episode's game object.""" return {name: self.game.rl_agents[name] for name in self._agent_ids} def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() info = {} return next_obs, info def step( self, actions: Dict[str, ActType] ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. :type actions: Dict[str, ActType] :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent identifier. :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] """ # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() # 3. Get next observations state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() # 4. Get rewards rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} infos = {name: {} for name, _ in self.agents.items()} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: self._write_step_metadata_json(actions, state, rewards) return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards, "state": state, } with open(path, "w") as file: json.dump(data, file) def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) return obs