2623 Ray single agent action masking

This commit is contained in:
Marek Wolan
2024-07-09 15:27:03 +01:00
parent 470fa28ee1
commit 5367f9ad53
5 changed files with 214 additions and 16 deletions

View File

@@ -741,6 +741,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -69,6 +69,8 @@ class AgentSettings(BaseModel):
"Configuration for when an agent begins performing its actions"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = True
"Whether to return action masks at each step."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
@@ -205,6 +207,7 @@ class ProxyAgent(AbstractAgent):
)
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -0,0 +1,184 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Action Masking\n",
"\n",
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = PrimaiteGymEnv(data_manipulation_config_path())\n",
"env.action_masking = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The action mask is an array of booleans (one per entry in the agent's action map) that specifies whether each action is currently possible. Demonstrated here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n",
"mask = env.action_masks()\n",
"actions = env.agent.action_manager.action_map\n",
"max_str_len = 70\n",
"for act,mask in zip(actions.items(), mask):\n",
" act_num, act_data = act\n",
" act_type, act_params = act_data\n",
" act_params = s if len(s:=str(act_params))<max_str_len else f\"{s[:max_str_len-3]}...\"\n",
" act_table.add_row((act_num, act_type, act_params, mask))\n",
"print(act_table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking for Stable Baselines3 agents\n",
"Maskable SB3 agents, such as `MaskablePPO` from `sb3_contrib`, automatically call the environment's `action_masks` method during the training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sb3_contrib import MaskablePPO\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = MaskablePPO(\"MlpPolicy\", env, gamma=0.4, seed=32)\n",
"model.learn(1024)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking for Ray RLLib agents\n",
"Ray RLlib uses a different API to obtain action masks, but this is handled by the PrimaiteRayEnv and PrimaiteRayMARLEnv classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"for agent in cfg['agents']:\n",
" if agent[\"ref\"] == \"defender\":\n",
" agent['agent_settings']['flatten_obs'] = True\n",
"env_config = cfg\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,9 +1,10 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
from os import PathLike
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
import numpy as np
from gymnasium.core import ActType, ObsType
from primaite import getLogger
@@ -40,10 +41,8 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Current episode number."""
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
self.action_masking: bool = False
"""Whether to use action masking."""
def action_masks(self) -> List[bool]:
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -54,13 +53,13 @@ class PrimaiteGymEnv(gymnasium.Env):
:rtype: numpy.ndarray
"""
mask = [True] * len(self.agent.action_manager.action_map)
if not self.action_masking:
if not self.agent.action_masking:
return mask
for i, action in self.agent.action_manager.action_map.items():
request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.game.simulation._request_manager.check_valid(request, {})
return mask
return np.asarray(mask)
@property
def agent(self) -> ProxyAgent:

View File

@@ -3,6 +3,7 @@ import json
from typing import Dict, SupportsFloat, Tuple
import gymnasium
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -38,15 +39,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
self.observation_space = spaces.Dict(
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
)
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
self._obs_space_in_preferred_format = True
self._action_space_in_preferred_format = True
super().__init__()
@@ -158,15 +154,30 @@ class PrimaiteRayEnv(gymnasium.Env):
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
if self.env.agent.action_masking:
self.observation_space = spaces.Dict(
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
)
else:
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
return new_obs, *_
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
# if action masking is enabled, intercept the step method and add action mask to observation
if self.env.agent.action_masking:
obs, *_ = self.env.step(action)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
return new_obs, *_
else:
return self.env.step(action)
def close(self):
"""Close the simulation."""