2623 move action mask generation to game and fix MARL masking

This commit is contained in:
Marek Wolan
2024-07-09 15:59:50 +01:00
parent 5367f9ad53
commit faf268a9b9
6 changed files with 101 additions and 16 deletions

View File

@@ -733,6 +733,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true
- ref: defender_2
team: BLUE
@@ -1316,6 +1317,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path:
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path
def data_manipulation_marl_config_path() -> Path:
"""
Get the path to the MARL example config.
:return: Path to yaml config file for the MARL scenario.
:rtype: Path
"""
path = _EXAMPLE_CFG / "data_manipulation_marl.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -3,6 +3,7 @@
from ipaddress import IPv4Address
from typing import Dict, List, Optional
import numpy as np
from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
@@ -192,6 +193,23 @@ class PrimaiteGame:
return True
return False
def action_mask(self, agent_name: str) -> np.ndarray:
"""
Return the action mask for the agent.
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
performed during this step.
:return: Action mask
:rtype: List[bool]
"""
agent = self.agents[agent_name]
mask = [True] * len(agent.action_manager.action_map)
for i, action in agent.action_manager.action_map.items():
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.simulation._request_manager.check_valid(request, {})
return np.asarray(mask)
def close(self) -> None:
"""Close the game, this will close the simulation."""
return NotImplemented

View File

@@ -96,7 +96,7 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n"
@@ -146,18 +146,59 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## Action masking with MARL in Ray RLLib\n",
"Each agent has their own action mask, this is useful if the agents have different action spaces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"from primaite.config.load import data_manipulation_marl_config_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_marl_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"env_config = cfg\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .multi_agent(\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" )\n",
"\n",
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128},\n",
" ),\n",
" param_space=config\n",
").fit()"
]
}
],
"metadata": {

View File

@@ -52,14 +52,10 @@ class PrimaiteGymEnv(gymnasium.Env):
:return: Action mask
:rtype: List[bool]
"""
mask = [True] * len(self.agent.action_manager.action_map)
if not self.agent.action_masking:
return mask
for i, action in self.agent.action_manager.action_map.items():
request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.game.simulation._request_manager.check_valid(request, {})
return np.asarray(mask)
return np.asarray([True] * len(self.agent.action_manager.action_map))
else:
return self.game.action_mask(self._agent_name)
@property
def agent(self) -> ProxyAgent:

View File

@@ -42,6 +42,15 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.observation_space = spaces.Dict(
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
)
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
if agent.action_masking:
self.observation_space[agent_name] = spaces.Dict(
{
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
"observations": self.observation_space[agent_name],
}
)
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
self._obs_space_in_preferred_format = True
self._action_space_in_preferred_format = True
@@ -127,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
all_obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
if agent.action_masking:
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
else:
all_obs[agent_name] = obs
return all_obs
def close(self):
"""Close the simulation."""