2623 move action mask generation to game and fix MARL masking
This commit is contained in:
@@ -733,6 +733,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
- ref: defender_2
|
||||
team: BLUE
|
||||
@@ -1316,6 +1317,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path:
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
|
||||
|
||||
def data_manipulation_marl_config_path() -> Path:
|
||||
"""
|
||||
Get the path to the MARL example config.
|
||||
|
||||
:return: Path to yaml config file for the MARL scenario.
|
||||
:rtype: Path
|
||||
"""
|
||||
path = _EXAMPLE_CFG / "data_manipulation_marl.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
@@ -192,6 +193,23 @@ class PrimaiteGame:
|
||||
return True
|
||||
return False
|
||||
|
||||
def action_mask(self, agent_name: str) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
|
||||
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
|
||||
performed during this step.
|
||||
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
agent = self.agents[agent_name]
|
||||
mask = [True] * len(agent.action_manager.action_map)
|
||||
for i, action in agent.action_manager.action_map.items():
|
||||
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
|
||||
mask[i] = self.simulation._request_manager.check_valid(request, {})
|
||||
return np.asarray(mask)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@@ -96,7 +96,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"import yaml\n",
|
||||
"from ray import air, tune\n"
|
||||
@@ -146,18 +146,59 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"## Action masking with MARL in Ray RLLib\n",
|
||||
"Each agent has their own action mask, this is useful if the agents have different action spaces."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"from primaite.config.load import data_manipulation_marl_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_marl_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"env_config = cfg\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .multi_agent(\n",
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128},\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -52,14 +52,10 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
mask = [True] * len(self.agent.action_manager.action_map)
|
||||
if not self.agent.action_masking:
|
||||
return mask
|
||||
|
||||
for i, action in self.agent.action_manager.action_map.items():
|
||||
request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
|
||||
mask[i] = self.game.simulation._request_manager.check_valid(request, {})
|
||||
return np.asarray(mask)
|
||||
return np.asarray([True] * len(self.agent.action_manager.action_map))
|
||||
else:
|
||||
return self.game.action_mask(self._agent_name)
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
|
||||
@@ -42,6 +42,15 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
self.observation_space = spaces.Dict(
|
||||
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
|
||||
)
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
if agent.action_masking:
|
||||
self.observation_space[agent_name] = spaces.Dict(
|
||||
{
|
||||
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
|
||||
"observations": self.observation_space[agent_name],
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
|
||||
self._obs_space_in_preferred_format = True
|
||||
self._action_space_in_preferred_format = True
|
||||
@@ -127,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
all_obs = {}
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if agent.action_masking:
|
||||
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
|
||||
else:
|
||||
all_obs[agent_name] = obs
|
||||
return all_obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
|
||||
Reference in New Issue
Block a user