2623 Ray single agent action masking
This commit is contained in:
@@ -741,6 +741,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -69,6 +69,8 @@ class AgentSettings(BaseModel):
|
||||
"Configuration for when an agent begins performing its actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = True
|
||||
"Whether to return action masks at each step."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
@@ -205,6 +207,7 @@ class ProxyAgent(AbstractAgent):
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
184
src/primaite/notebooks/Action-masking.ipynb
Normal file
184
src/primaite/notebooks/Action-masking.ipynb
Normal file
@@ -0,0 +1,184 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Action Masking\n",
|
||||
"\n",
|
||||
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are valid in the current environment state. For example, a node can only be turned on if it is currently turned off."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = PrimaiteGymEnv(data_manipulation_config_path())\n",
|
||||
"env.action_masking = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The action mask is an array of booleans that specifies whether each action in the agent's action map is currently possible, as demonstrated here:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n",
|
||||
"mask = env.action_masks()\n",
|
||||
"actions = env.agent.action_manager.action_map\n",
|
||||
"max_str_len = 70\n",
|
||||
"for act,mask in zip(actions.items(), mask):\n",
|
||||
" act_num, act_data = act\n",
|
||||
" act_type, act_params = act_data\n",
|
||||
" act_params = s if len(s:=str(act_params))<max_str_len else f\"{s[:max_str_len-3]}...\"\n",
|
||||
" act_table.add_row((act_num, act_type, act_params, mask))\n",
|
||||
"print(act_table)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Action masking for Stable Baselines3 agents\n",
|
||||
"Maskable SB3 agents (such as `MaskablePPO` from `sb3_contrib`) automatically use the environment's `action_masks` method during the training loop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sb3_contrib import MaskablePPO\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = MaskablePPO(\"MlpPolicy\", env, gamma=0.4, seed=32)\n",
|
||||
"model.learn(1024)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Action masking for Ray RLLib agents\n",
|
||||
"Ray uses a different API to obtain action masks, but this is handled by the `PrimaiteRayEnv` and `PrimaiteRayMARLEnv` classes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"import yaml\n",
|
||||
"from ray import air, tune\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent[\"ref\"] == \"defender\":\n",
|
||||
" agent['agent_settings']['flatten_obs'] = True\n",
|
||||
"env_config = cfg\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=cfg)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -40,10 +41,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Current episode number."""
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
self.action_masking: bool = False
|
||||
"""Whether to use action masking."""
|
||||
|
||||
def action_masks(self) -> List[bool]:
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
|
||||
@@ -54,13 +53,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
mask = [True] * len(self.agent.action_manager.action_map)
|
||||
if not self.action_masking:
|
||||
if not self.agent.action_masking:
|
||||
return mask
|
||||
|
||||
for i, action in self.agent.action_manager.action_map.items():
|
||||
request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
|
||||
mask[i] = self.game.simulation._request_manager.check_valid(request, {})
|
||||
return mask
|
||||
return np.asarray(mask)
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from typing import Dict, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
@@ -38,15 +39,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
self.observation_space = spaces.Dict(
|
||||
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
|
||||
)
|
||||
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
|
||||
self._obs_space_in_preferred_format = True
|
||||
self._action_space_in_preferred_format = True
|
||||
super().__init__()
|
||||
@@ -158,15 +154,30 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
if self.env.agent.action_masking:
|
||||
self.observation_space = spaces.Dict(
|
||||
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
|
||||
)
|
||||
else:
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
else:
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
|
||||
Reference in New Issue
Block a user