From 5367f9ad5376241ee40005381652cc13f628c53c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 9 Jul 2024 15:27:03 +0100 Subject: [PATCH] 2623 Ray single agent action masking --- .../_package_data/data_manipulation.yaml | 1 + src/primaite/game/agent/interface.py | 3 + src/primaite/notebooks/Action-masking.ipynb | 184 ++++++++++++++++++ src/primaite/session/environment.py | 11 +- src/primaite/session/ray_envs.py | 31 ++- 5 files changed, 214 insertions(+), 16 deletions(-) create mode 100644 src/primaite/notebooks/Action-masking.ipynb diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 6d4ec9b4..97442903 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -741,6 +741,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 95468331..01b7fb0a 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -69,6 +69,8 @@ class AgentSettings(BaseModel): "Configuration for when an agent begins performing it's actions" flatten_obs: bool = True "Whether to flatten the observation space before passing it to the agent. True by default." + action_masking: bool = True + "Whether to return action masks at each step." @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": @@ -205,6 +207,7 @@ class ProxyAgent(AbstractAgent): ) self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False + self.action_masking: bool = agent_settings.action_masking if agent_settings else False def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb new file mode 100644 index 00000000..822b8451 --- /dev/null +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Action Masking\n", + "\n", + "PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from prettytable import PrettyTable" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env = PrimaiteGymEnv(data_manipulation_config_path())\n", + "env.action_masking = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible. Demonstrated here:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n", + "mask = env.action_masks()\n", + "actions = env.agent.action_manager.action_map\n", + "max_str_len = 70\n", + "for act,mask in zip(actions.items(), mask):\n", + " act_num, act_data = act\n", + " act_type, act_params = act_data\n", + " act_params = s if len(s:=str(act_params)) List[bool]: + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -54,13 +53,13 @@ class PrimaiteGymEnv(gymnasium.Env): :rtype: List[bool] """ mask = [True] * len(self.agent.action_manager.action_map) - if not self.action_masking: + if not self.agent.action_masking: return mask for i, action in self.agent.action_manager.action_map.items(): request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) mask[i] = self.game.simulation._request_manager.check_valid(request, {}) - return mask + return np.asarray(mask) @property def agent(self) -> ProxyAgent: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index fc5d73d8..1fc7624f 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -3,6 +3,7 @@ import json from typing import Dict, SupportsFloat, Tuple import gymnasium +from gymnasium import spaces from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -38,15 +39,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.terminateds = set() self.truncateds = set() - self.observation_space = gymnasium.spaces.Dict( - { - name: gymnasium.spaces.flatten_space(agent.observation_manager.space) - for name, agent in self.agents.items() - } - ) - self.action_space = gymnasium.spaces.Dict( - {name: agent.action_manager.space for name, agent in self.agents.items()} + self.observation_space = spaces.Dict( + {name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()} ) + self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()}) self._obs_space_in_preferred_format = True self._action_space_in_preferred_format = True super().__init__() @@ -158,15 +154,30 @@ class PrimaiteRayEnv(gymnasium.Env): self.env = PrimaiteGymEnv(env_config=env_config) # self.env.episode_counter -= 1 self.action_space = self.env.action_space - self.observation_space = self.env.observation_space + if self.env.agent.action_masking: + self.observation_space = spaces.Dict( + {"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space} + ) + else: + self.observation_space = self.env.observation_space def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + if self.env.agent.action_masking: + obs, *_ = self.env.reset(seed=seed) + new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + return new_obs, *_ return self.env.reset(seed=seed) def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" - return self.env.step(action) + # if action masking is enabled, intercept the step method and add action mask to observation + if self.env.agent.action_masking: + obs, *_ = self.env.step(action) + new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + return new_obs, *_ + else: + return self.env.step(action) def close(self): """Close the simulation."""