2623 Add e2e tests for action masking

Author: Marek Wolan
Date: 2024-07-10 11:01:42 +01:00
Parent: faf268a9b9
Commit: 7201b7b8e0
8 changed files with 897 additions and 319 deletions

View File

@@ -208,7 +208,7 @@ class PrimaiteGame:
         for i, action in agent.action_manager.action_map.items():
             request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
             mask[i] = self.simulation._request_manager.check_valid(request, {})
-        return np.asarray(mask)
+        return np.asarray(mask, dtype=np.int8)
 
     def close(self) -> None:
         """Close the game, this will close the simulation."""

View File

@@ -17,7 +17,7 @@
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable"
"from prettytable import PrettyTable\n"
]
},
{
@@ -99,7 +99,9 @@
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n"
"from ray import air, tune\n",
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
]
},
{
@@ -124,25 +126,15 @@
"source": [
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg)\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
},
{
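The single-agent cell now drops the separate tune.Tuner cell and instead builds the algorithm directly and calls train() twice. If a Tune-driven run is still wanted, something like the following should be equivalent (a hedged sketch, not part of this commit; it assumes `config` is the masked PPOConfig built in the cell above):

    # Hedged alternative, not part of this commit: the same masked PPOConfig can
    # still be driven through Tune instead of calling config.build() directly.
    # Assumes `config` is the PPOConfig defined in the cell above.
    from ray import air, tune

    tune.Tuner(
        "PPO",
        param_space=config,
        # Stop after two training iterations, matching the explicit loop above.
        run_config=air.RunConfig(stop={"training_iteration": 2}),
    ).fit()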
@@ -159,6 +151,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"from primaite.config.load import data_manipulation_marl_config_path"
]
@@ -184,20 +177,20 @@
" PPOConfig()\n",
" .multi_agent(\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
" \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" }))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" )\n",
"\n",
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128},\n",
" ),\n",
" param_space=config\n",
").fit()"
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
}
],
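The multi-agent cell mirrors the single-agent changes, and the policy_mapping_fn signature is loosened from (agent_id, episode, worker, **kw) to (agent_id, *args, **kwargs). A named equivalent, for clarity (a sketch; the assumption is that the new RLlib API stack no longer passes a worker argument to this callback, so the permissive signature works on both stacks):

    # Named equivalent of the lambda above (sketch; assumes the new RLlib API
    # stack calls this without a `worker` argument, so the permissive signature
    # is compatible with both the old and the new stack).
    def policy_mapping_fn(agent_id, *args, **kwargs):
        # Each PrimAITE agent name ("defender_1", "defender_2") doubles as its policy ID.
        return agent_id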

View File

@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
         # if action masking is enabled, intercept the step method and add action mask to observation
         if self.env.agent.action_masking:
             obs, *_ = self.env.step(action)
-            new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
+            new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
             return new_obs, *_
         else:
             return self.env.step(action)
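With this change the wrapped step() builds the mask from the game rather than the underlying Gym env, and returns a dict keyed by "action_mask" and "observations", the layout that ActionMaskingTorchRLModule looks for. A hedged sketch of the corresponding observation space (the sizes are illustrative, not taken from the PrimAITE config):

    # Illustrative sketch of the Dict observation layout the masked env exposes;
    # the sizes below are made up, not taken from the PrimAITE config.
    import numpy as np
    from gymnasium import spaces

    num_actions = 10  # size of the flat Discrete action space (illustrative)
    obs_size = 32     # length of the real observation vector (illustrative)

    observation_space = spaces.Dict(
        {
            # One 0/1 flag per action, matching the int8 mask built by the game.
            "action_mask": spaces.Box(0, 1, shape=(num_actions,), dtype=np.int8),
            # The unmasked observation the agent would normally receive.
            "observations": spaces.Box(0.0, 1.0, shape=(obs_size,), dtype=np.float32),
        }
    )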