2623 Add e2e tests for action masking

This commit is contained in:
Marek Wolan
2024-07-10 11:01:42 +01:00
parent faf268a9b9
commit 7201b7b8e0
8 changed files with 897 additions and 319 deletions

View File

@@ -208,7 +208,7 @@ class PrimaiteGame:
for i, action in agent.action_manager.action_map.items():
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.simulation._request_manager.check_valid(request, {})
return np.asarray(mask)
return np.asarray(mask, dtype=np.int8)
def close(self) -> None:
"""Close the game, this will close the simulation."""

View File

@@ -17,7 +17,7 @@
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable"
"from prettytable import PrettyTable\n"
]
},
{
@@ -99,7 +99,9 @@
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n"
"from ray import air, tune\n",
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
]
},
{
@@ -124,25 +126,15 @@
"source": [
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg)\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
},
{
@@ -159,6 +151,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"from primaite.config.load import data_manipulation_marl_config_path"
]
@@ -184,20 +177,20 @@
" PPOConfig()\n",
" .multi_agent(\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
" \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" }))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" )\n",
"\n",
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128},\n",
" ),\n",
" param_space=config\n",
").fit()"
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
}
],

View File

@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
# if action masking is enabled, intercept the step method and add action mask to observation
if self.env.agent.action_masking:
obs, *_ = self.env.step(action)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
return new_obs, *_
else:
return self.env.step(action)

File diff suppressed because it is too large Load Diff

View File

@@ -557,6 +557,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,160 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import importlib
from typing import Dict
import yaml
from ray import air, init, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule
from sb3_contrib import MaskablePPO
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
from tests import TEST_ASSETS_ROOT
init(local_mode=True)
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def test_sb3_action_masking(monkeypatch):
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action
# mask function here to save the output of the action mask method and pass through the result back to the agent.
old_action_mask_method = PrimaiteGame.action_mask
mask_history = []
def cache_action_mask(obj, agent_name):
mask = old_action_mask_method(obj, agent_name)
mask_history.append(mask)
return mask
# Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which
# action map action integer that was, therefore we cache it by using monkeypatch
action_num_history = []
def cache_step(env, action: int):
action_num_history.append(action)
return PrimaiteGymEnv.step(env, action)
monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask)
env = PrimaiteGymEnv(CFG_PATH)
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
model.learn(512)
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
assert any([not all(x) for x in mask_history])
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
# the N-th entry was True, meaning that it was a valid action at that step.
# This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this
# happens for all steps i in the episode
assert all(mask_history[i][a] for i, a in enumerate(action_num_history))
monkeypatch.undo()
def test_ray_single_agent_action_masking(monkeypatch):
"""Check that a Ray agent uses the action mask and never chooses invalid actions."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
for agent in cfg["agents"]:
if agent["ref"] == "defender":
agent["agent_settings"]["flatten_obs"] = True
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step
# function to save the action mask and the agent's chosen action to a local variable.
old_step_method = PrimaiteRayEnv.step
action_num_history = []
mask_history = []
def cache_step(self, action: int):
action_num_history.append(action)
obs, *_ = old_step_method(self, action)
action_mask = obs["action_mask"]
mask_history.append(action_mask)
return obs, *_
monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
# Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule
config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
.environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask")
.rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule))
.env_runners(num_env_runners=0)
.training(train_batch_size=128)
)
algo = config.build()
algo.train()
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
assert any([not all(x) for x in mask_history])
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
# the N-th action was valid.
# The first step uses the action mask provided by the reset method, so we are only checking from the second step
# onward, that's why we need to use mask_history[:-1] and action_num_history[1:]
assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:]))
monkeypatch.undo()
def test_ray_multi_agent_action_masking(monkeypatch):
"""Check that Ray agents never take invalid actions when using MARL."""
with open(MARL_PATH, "r") as f:
cfg = yaml.safe_load(f)
old_step_method = PrimaiteRayMARLEnv.step
action_num_history = {"defender_1": [], "defender_2": []}
mask_history = {"defender_1": [], "defender_2": []}
def cache_step(self, actions: Dict[str, int]):
for agent_name, action in actions.items():
action_num_history[agent_name].append(action)
obs, *_ = old_step_method(self, actions)
for (
agent_name,
o,
) in obs.items():
mask_history[agent_name].append(o["action_mask"])
return obs, *_
monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
config = (
PPOConfig()
.multi_agent(
policies={
"defender_1",
"defender_2",
}, # These names are the same as the agents defined in the example config.
policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
)
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
.environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask")
.rl_module(
rl_module_spec=MultiAgentRLModuleSpec(
module_specs={
"defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
"defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
}
)
)
.env_runners(num_env_runners=0)
.training(train_batch_size=128)
)
algo = config.build()
algo.train()
for agent_name in ["defender_1", "defender_2"]:
act_hist = action_num_history[agent_name]
mask_hist = mask_history[agent_name]
assert len(act_hist) == len(mask_hist) > 0
assert any([not all(x) for x in mask_hist])
assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:]))
monkeypatch.undo()

View File

@@ -99,7 +99,7 @@ class TestConfigureDatabaseAction:
game.step()
assert db_client.server_ip_address == old_ip
assert db_client.server_password is "admin123"
assert db_client.server_password == "admin123"
class TestConfigureRansomwareScriptAction: