From faf268a9b9b4fd847bab7f31518b8488dadc169b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 9 Jul 2024 15:59:50 +0100 Subject: [PATCH] 2623 move action mask generation to game and fix MARL masking --- .../_package_data/data_manipulation_marl.yaml | 2 + src/primaite/config/load.py | 15 ++++++ src/primaite/game/game.py | 18 +++++++ src/primaite/notebooks/Action-masking.ipynb | 53 ++++++++++++++++--- src/primaite/session/environment.py | 10 ++-- src/primaite/session/ray_envs.py | 19 +++++-- 6 files changed, 101 insertions(+), 16 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index 2e8221a0..ba666781 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -733,6 +733,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true - ref: defender_2 team: BLUE @@ -1316,6 +1317,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index 3483fc87..144e0733 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path: _LOGGER.error(msg) raise FileNotFoundError(msg) return path + + +def data_manipulation_marl_config_path() -> Path: + """ + Get the path to the MARL example config. + + :return: Path to yaml config file for the MARL scenario. + :rtype: Path + """ + path = _EXAMPLE_CFG / "data_manipulation_marl.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 3dc9571f..e7d13061 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional +import numpy as np from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger @@ -192,6 +193,23 @@ class PrimaiteGame: return True return False + def action_mask(self, agent_name: str) -> np.ndarray: + """ + Return the action mask for the agent. + + This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be + performed during this step. + + :return: Action mask + :rtype: List[bool] + """ + agent = self.agents[agent_name] + mask = [True] * len(agent.action_manager.action_map) + for i, action in agent.action_manager.action_map.items(): + request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) + mask[i] = self.simulation._request_manager.check_valid(request, {}) + return np.asarray(mask) + def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index 822b8451..8090dacc 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "import yaml\n", "from ray import air, tune\n" @@ -146,18 +146,59 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "## Action masking with MARL in Ray RLLib\n", + "Each agent has their own action mask, this is useful if the agents have different action spaces." + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", + "from primaite.config.load import data_manipulation_marl_config_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_marl_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "env_config = cfg\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = (\n", + " PPOConfig()\n", + " .multi_agent(\n", + " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", + " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " )\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .env_runners(num_env_runners=0)\n", + " .training(train_batch_size=128)\n", + " )\n", + "\n", + "tune.Tuner(\n", + " \"PPO\",\n", + " run_config=air.RunConfig(\n", + " stop={\"timesteps_total\": 5 * 128},\n", + " ),\n", + " param_space=config\n", + ").fit()" + ] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 0520cce9..a87f0cde 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -52,14 +52,10 @@ class PrimaiteGymEnv(gymnasium.Env): :return: Action mask :rtype: List[bool] """ - mask = [True] * len(self.agent.action_manager.action_map) if not self.agent.action_masking: - return mask - - for i, action in self.agent.action_manager.action_map.items(): - request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) - mask[i] = self.game.simulation._request_manager.check_valid(request, {}) - return np.asarray(mask) + return np.asarray([True] * len(self.agent.action_manager.action_map)) + else: + return self.game.action_mask(self._agent_name) @property def agent(self) -> ProxyAgent: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 1fc7624f..12167f89 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -42,6 +42,15 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.observation_space = spaces.Dict( {name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()} ) + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] + if agent.action_masking: + self.observation_space[agent_name] = spaces.Dict( + { + "action_mask": spaces.MultiBinary(agent.action_manager.space.n), + "observations": self.observation_space[agent_name], + } + ) self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()}) self._obs_space_in_preferred_format = True self._action_space_in_preferred_format = True @@ -127,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" - obs = {} + all_obs = {} for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) - return obs + obs = gymnasium.spaces.flatten(unflat_space, unflat_obs) + if agent.action_masking: + all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs} + else: + all_obs[agent_name] = obs + return all_obs def close(self): """Close the simulation."""