2623 Add e2e tests for action masking
This commit is contained in:
@@ -208,7 +208,7 @@ class PrimaiteGame:
|
||||
for i, action in agent.action_manager.action_map.items():
|
||||
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
|
||||
mask[i] = self.simulation._request_manager.check_valid(request, {})
|
||||
return np.asarray(mask)
|
||||
return np.asarray(mask, dtype=np.int8)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable"
|
||||
"from prettytable import PrettyTable\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -99,7 +99,9 @@
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"import yaml\n",
|
||||
"from ray import air, tune\n"
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
|
||||
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -124,25 +126,15 @@
|
||||
"source": [
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=cfg)\n",
|
||||
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
|
||||
" .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
")\n",
|
||||
"algo = config.build()\n",
|
||||
"for i in range(2):\n",
|
||||
" results = algo.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -159,6 +151,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"from primaite.config.load import data_manipulation_marl_config_path"
|
||||
]
|
||||
@@ -184,20 +177,20 @@
|
||||
" PPOConfig()\n",
|
||||
" .multi_agent(\n",
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
|
||||
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
|
||||
" .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
|
||||
" \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
|
||||
" \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
|
||||
" }))\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128},\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()"
|
||||
")\n",
|
||||
"algo = config.build()\n",
|
||||
"for i in range(2):\n",
|
||||
" results = algo.train()"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
|
||||
return new_obs, *_
|
||||
else:
|
||||
return self.env.step(action)
|
||||
|
||||
Reference in New Issue
Block a user