{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Action Masking\n",
    "\n",
    "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n",
    "\n",
    "PrimAITE environments support action masking. The action mask indicates which of the agent's actions are applicable in the current environment state; for example, a node can only be turned on if it is currently turned off. Please refer to the action masking configuration page of the user guide for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!primaite setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from primaite.session.environment import PrimaiteGymEnv\n",
    "from primaite.config.load import data_manipulation_config_path\n",
    "from prettytable import PrettyTable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = PrimaiteGymEnv(data_manipulation_config_path())\n",
    "# Enable action masking on the environment.\n",
    "env.action_masking = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The action mask is a list of booleans specifying whether each action in the agent's action map is currently possible, as demonstrated here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n",
    "masks = env.action_masks()\n",
    "actions = env.agent.action_manager.action_map\n",
    "max_str_len = 70  # Truncate long parameter strings to keep the table readable.\n",
    "for act, mask in zip(actions.items(), masks):\n",
    "    act_num, act_data = act\n",
    "    act_type, act_params = act_data\n",
    "    act_params = s if len(s := str(act_params)) < max_str_len else f\"{s[:max_str_len - 3]}...\"\n",
    "    act_table.add_row((act_num, act_type, act_params, mask))\n",
    "print(act_table)"
   ]
  },
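  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mask can also be used directly. A minimal sketch, assuming numpy is installed and the environment follows the Gymnasium step API: recover the indices of the currently valid actions from the boolean mask and sample one at random."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Indices of the actions that the mask currently permits.\n",
    "valid_actions = np.flatnonzero(env.action_masks())\n",
    "# Sample one valid action at random and apply it to the environment.\n",
    "obs, reward, terminated, truncated, info = env.step(np.random.choice(valid_actions))"
   ]
  },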
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Action masking for Stable Baselines3 agents\n",
    "Maskable agents from the sb3_contrib package, such as MaskablePPO, automatically call the environment's action_masks method during the training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sb3_contrib import MaskablePPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MaskablePPO(\"MlpPolicy\", env, gamma=0.4, seed=32)\n",
    "model.learn(1024)"
   ]
  },
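  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At inference time the mask can be passed to the predict method explicitly. A minimal sketch, assuming the Gymnasium reset/step API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs, info = env.reset()\n",
    "# Supply the current mask so the policy only selects valid actions.\n",
    "action, _states = model.predict(obs, action_masks=env.action_masks())\n",
    "obs, reward, terminated, truncated, info = env.step(action)"
   ]
  },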
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Action masking for Ray RLLib agents\n",
    "Ray RLLib uses a different API to obtain action masks, but this is handled automatically by the PrimaiteRayEnv and PrimaiteRayMARLEnv classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "from ray.rllib.algorithms.ppo import PPOConfig\n",
    "from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n",
    "from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
    "\n",
    "from primaite.session.ray_envs import PrimaiteRayEnv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(data_manipulation_config_path(), \"r\") as f:\n",
    "    cfg = yaml.safe_load(f)\n",
    "# Configure the defender agent to use flattened observations.\n",
    "for agent in cfg[\"agents\"]:\n",
    "    if agent[\"ref\"] == \"defender\":\n",
    "        agent[\"agent_settings\"][\"flatten_obs\"] = True"
   ]
  },
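  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check (a sketch; the exact constructor and reset signatures are assumptions), PrimaiteRayEnv can be built directly to confirm that the observation dictionary exposes the mask under the action_mask key used by the config below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumption: PrimaiteRayEnv takes the config dict and follows the Gymnasium API.\n",
    "check_env = PrimaiteRayEnv(cfg)\n",
    "obs, info = check_env.reset()\n",
    "print(obs.keys())  # Expected to include \"action_mask\"."
   ]
  },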
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = (\n",
    "    PPOConfig()\n",
    "    .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
    "    .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
    "    .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule))\n",
    "    .env_runners(num_env_runners=0)\n",
    "    .training(train_batch_size=128)\n",
    ")\n",
    "algo = config.build()\n",
    "results = algo.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Action masking with MARL in Ray RLLib\n",
    "\n",
    "Each agent has its own action mask, which is useful in multi-agent environments where the agents are configured with different action spaces.\n",
    "\n",
    "The code snippets below demonstrate how to train multiple agents with action masks using the [UC2 MARL example config](./Training-an-RLLIB-MARL-System.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
    "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
    "from primaite.config.load import data_manipulation_marl_config_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(data_manipulation_marl_config_path(), \"r\") as f:\n",
    "    cfg = yaml.safe_load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = (\n",
    "    PPOConfig()\n",
    "    .multi_agent(\n",
    "        # These policy names match the agents defined in the example config.\n",
    "        policies={\"defender_1\", \"defender_2\"},\n",
    "        policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
    "    )\n",
    "    .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
    "    .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
    "    .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
    "        \"defender_1\": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
    "        \"defender_2\": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
    "    }))\n",
    "    .env_runners(num_env_runners=0)\n",
    "    .training(train_batch_size=128)\n",
    ")\n",
    "algo = config.build()\n",
    "results = algo.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}