From 7201b7b8e096247621eecce69e078a9ebb14a6ed Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 10 Jul 2024 11:01:42 +0100 Subject: [PATCH] 2623 Add e2e tests for action masking --- src/primaite/game/game.py | 2 +- src/primaite/notebooks/Action-masking.ipynb | 53 +- src/primaite/session/ray_envs.py | 2 +- tests/assets/configs/multi_agent_session.yaml | 995 +++++++++++++----- .../assets/configs/test_primaite_session.yaml | 1 + .../action_masking/__init__.py | 1 + .../test_agents_use_action_masks.py | 160 +++ .../actions/test_configure_actions.py | 2 +- 8 files changed, 897 insertions(+), 319 deletions(-) create mode 100644 tests/e2e_integration_tests/action_masking/__init__.py create mode 100644 tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e7d13061..252d1ce2 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -208,7 +208,7 @@ class PrimaiteGame: for i, action in agent.action_manager.action_map.items(): request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) mask[i] = self.simulation._request_manager.check_valid(request, {}) - return np.asarray(mask) + return np.asarray(mask, dtype=np.int8) def close(self) -> None: """Close the game, this will close the simulation.""" diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index 8090dacc..0e067b26 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -17,7 +17,7 @@ "source": [ "from primaite.session.environment import PrimaiteGymEnv\n", "from primaite.config.load import data_manipulation_config_path\n", - "from prettytable import PrettyTable" + "from prettytable import PrettyTable\n" ] }, { @@ -99,7 +99,9 @@ "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "import yaml\n", - "from ray import air, tune\n" + "from ray import air, tune\n", + "from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n", + "from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n" ] }, { @@ -124,25 +126,15 @@ "source": [ "config = (\n", " PPOConfig()\n", - " .environment(env=PrimaiteRayEnv, env_config=cfg)\n", + " .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n", + " .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n", + " .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n", " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tune.Tuner(\n", - " \"PPO\",\n", - " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 512}\n", - " ),\n", - " param_space=config\n", - ").fit()\n" + ")\n", + "algo = config.build()\n", + "for i in range(2):\n", + " results = algo.train()" ] }, { @@ -159,6 +151,7 @@ "metadata": {}, "outputs": [], "source": [ + "from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n", "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", "from primaite.config.load import data_manipulation_marl_config_path" ] @@ -184,20 +177,20 @@ " PPOConfig()\n", " .multi_agent(\n", " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", - " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n", + " .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n", + " \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n", + " \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n", + " }))\n", " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", - " )\n", - "\n", - "tune.Tuner(\n", - " \"PPO\",\n", - " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 5 * 128},\n", - " ),\n", - " param_space=config\n", - ").fit()" + ")\n", + "algo = config.build()\n", + "for i in range(2):\n", + " results = algo.train()" ] } ], diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 12167f89..1adc324c 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env): # if action masking is enabled, intercept the step method and add action mask to observation if self.env.agent.action_masking: obs, *_ = self.env.step(action) - new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs} return new_obs, *_ else: return self.env.step(action) diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 971f36f8..a2d64605 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1,3 +1,10 @@ +io_settings: + save_agent_actions: false + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + + game: max_episode_length: 128 ports: @@ -13,31 +20,105 @@ game: agents: - ref: client_2_green_user team: GREEN - type: PeriodicAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: null action_space: action_list: - type: DONOTHING - type: NODE_APPLICATION_EXECUTE - options: nodes: - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 - ref: data_manipulation_attacker team: RED @@ -57,6 +138,9 @@ agents: - node_name: client_1 applications: - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -71,7 +155,7 @@ agents: frequency: 20 variance: 5 - - ref: defender1 + - ref: defender_1 team: BLUE type: ProxyAgent @@ -194,318 +278,425 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_FIX" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_id: 0 + 22: + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 23: + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 24: # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 24: + action: NODE_STARTUP options: - target_router: router_1 - position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 25: # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 25: + action: NODE_RESET options: - target_router: router_1 - position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 26: - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + node_id: 2 27: - action: "ROUTER_ACL_ADDRULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + node_id: 2 28: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 0 + node_id: 2 29: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 1 + node_id: 2 30: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 2 + node_id: 3 31: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 3 + node_id: 3 32: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 4 + node_id: 3 33: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 5 + node_id: 3 34: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 6 + node_id: 4 35: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 7 + node_id: 4 36: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 8 + node_id: 4 37: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 9 + node_id: 4 38: - action: "HOST_NIC_DISABLE" + action: "NODE_OS_SCAN" options: - node_id: 0 - nic_id: 0 - 39: - action: "HOST_NIC_ENABLE" + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" options: - node_id: 0 - nic_id: 0 - 40: - action: "HOST_NIC_DISABLE" + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP options: - node_id: 1 - nic_id: 0 - 41: - action: "HOST_NIC_ENABLE" + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET options: - node_id: 1 - nic_id: 0 + node_id: 5 42: - action: "HOST_NIC_DISABLE" + action: "NODE_OS_SCAN" options: - node_id: 2 - nic_id: 0 + node_id: 6 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 action: "HOST_NIC_ENABLE" options: node_id: 2 nic_id: 0 - 44: + 68: # old action num: 44 action: "HOST_NIC_DISABLE" options: node_id: 3 nic_id: 0 - 45: + 69: # old action num: 45 action: "HOST_NIC_ENABLE" options: node_id: 3 nic_id: 0 - 46: + 70: # old action num: 46 action: "HOST_NIC_DISABLE" options: node_id: 4 nic_id: 0 - 47: + 71: # old action num: 47 action: "HOST_NIC_ENABLE" options: node_id: 4 nic_id: 0 - 48: + 72: # old action num: 48 action: "HOST_NIC_DISABLE" options: node_id: 4 nic_id: 1 - 49: + 73: # old action num: 49 action: "HOST_NIC_ENABLE" options: node_id: 4 nic_id: 1 - 50: + 74: # old action num: 50 action: "HOST_NIC_DISABLE" options: node_id: 5 nic_id: 0 - 51: + 75: # old action num: 51 action: "HOST_NIC_ENABLE" options: node_id: 5 nic_id: 0 - 52: + 76: # old action num: 52 action: "HOST_NIC_DISABLE" options: node_id: 6 nic_id: 0 - 53: + 77: # old action num: 53 action: "HOST_NIC_ENABLE" options: node_id: 6 nic_id: 0 - options: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -521,27 +712,30 @@ agents: - 192.168.10.22 - 192.168.10.110 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: web_server - service_name: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... + flatten_obs: true + action_masking: true - - ref: defender2 + - ref: defender_2 team: BLUE type: ProxyAgent @@ -640,7 +834,11 @@ agents: - type: NODE_STARTUP - type: NODE_RESET - type: ROUTER_ACL_ADDRULE + options: + target_router: router_1 - type: ROUTER_ACL_REMOVERULE + options: + target_router: router_1 - type: HOST_NIC_ENABLE - type: HOST_NIC_DISABLE @@ -664,99 +862,196 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_FIX" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -769,7 +1064,7 @@ agents: protocol_id: 1 source_wildcard_id: 0 dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -782,7 +1077,7 @@ agents: protocol_id: 1 source_wildcard_id: 0 dest_wildcard_id: 0 - 24: # block tcp traffic from client 1 to web app + 48: # old action num: 24 # block tcp traffic from client 1 to web app action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -795,7 +1090,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 25: # block tcp traffic from client 2 to web app + 49: # old action num: 25 # block tcp traffic from client 2 to web app action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -808,7 +1103,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 26: + 50: # old action num: 26 action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -821,7 +1116,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 27: + 51: # old action num: 27 action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -834,67 +1129,159 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 28: + 52: # old action num: 28 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 0 - 29: + 53: # old action num: 29 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 1 - 30: + 54: # old action num: 30 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 2 - 31: + 55: # old action num: 31 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 3 - 32: + 56: # old action num: 32 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 4 - 33: + 57: # old action num: 33 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 5 - 34: + 58: # old action num: 34 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 6 - 35: + 59: # old action num: 35 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 7 - 36: + 60: # old action num: 36 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 8 - 37: + 61: # old action num: 37 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + options: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -913,50 +1300,63 @@ agents: reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: web_server - service_name: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... - + flatten_obs: true + action_masking: true simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - - type: router - hostname: router_1 + - hostname: router_1 + type: router num_ports: 5 ports: 1: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: - 0: + 18: action: PERMIT src_port: POSTGRES_SERVER dst_port: POSTGRES_SERVER - 1: + 19: action: PERMIT src_port: DNS dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -965,16 +1365,16 @@ simulation: action: PERMIT protocol: ICMP - - type: switch - hostname: switch_1 + - hostname: switch_1 + type: switch num_ports: 8 - - type: switch - hostname: switch_2 + - hostname: switch_2 + type: switch num_ports: 8 - - type: server - hostname: domain_controller + - hostname: domain_controller + type: server ip_address: 192.168.1.10 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -984,8 +1384,8 @@ simulation: domain_mapping: arcd.com: 192.168.1.12 # web server - - type: server - hostname: web_server + - hostname: web_server + type: server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -997,17 +1397,21 @@ simulation: options: db_server_ip: 192.168.1.14 - - type: server - hostname: database_server + + - hostname: database_server + type: server ip_address: 192.168.1.14 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient - - type: server - hostname: backup_server + - hostname: backup_server + type: server ip_address: 192.168.1.16 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1015,8 +1419,8 @@ simulation: services: - type: FTPServer - - type: server - hostname: security_suite + - hostname: security_suite + type: server ip_address: 192.168.1.110 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1026,8 +1430,8 @@ simulation: ip_address: 192.168.10.110 subnet_mask: 255.255.255.0 - - type: computer - hostname: client_1 + - hostname: client_1 + type: computer ip_address: 192.168.10.21 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1035,24 +1439,43 @@ simulation: applications: - type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - type: DNSClient - - type: computer - hostname: client_2 + - hostname: client_2 + type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - type: DNSClient + + links: - endpoint_a_hostname: router_1 endpoint_a_port: 1 diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 54143af0..7c894ba0 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -557,6 +557,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/tests/e2e_integration_tests/action_masking/__init__.py b/tests/e2e_integration_tests/action_masking/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/e2e_integration_tests/action_masking/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py new file mode 100644 index 00000000..3efda71a --- /dev/null +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -0,0 +1,160 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import importlib +from typing import Dict + +import yaml +from ray import air, init, tune +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule +from sb3_contrib import MaskablePPO + +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from tests import TEST_ASSETS_ROOT + +init(local_mode=True) + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" +MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" + + +def test_sb3_action_masking(monkeypatch): + # There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action + # mask function here to save the output of the action mask method and pass through the result back to the agent. + old_action_mask_method = PrimaiteGame.action_mask + mask_history = [] + + def cache_action_mask(obj, agent_name): + mask = old_action_mask_method(obj, agent_name) + mask_history.append(mask) + return mask + + # Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which + # action map action integer that was, therefore we cache it by using monkeypatch + action_num_history = [] + + def cache_step(env, action: int): + action_num_history.append(action) + return PrimaiteGymEnv.step(env, action) + + monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask) + env = PrimaiteGymEnv(CFG_PATH) + monkeypatch.setattr(env, "step", lambda action: cache_step(env, action)) + + model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32) + model.learn(512) + + assert len(action_num_history) == len(mask_history) > 0 + # Make sure the masks had at least some False entries, if it was all True then the mask was disabled + assert any([not all(x) for x in mask_history]) + # When the agent takes action N from its action map, we need to have a look at the action mask and make sure that + # the N-th entry was True, meaning that it was a valid action at that step. + # This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this + # happens for all steps i in the episode + assert all(mask_history[i][a] for i, a in enumerate(action_num_history)) + monkeypatch.undo() + + +def test_ray_single_agent_action_masking(monkeypatch): + """Check that a Ray agent uses the action mask and never chooses invalid actions.""" + with open(CFG_PATH, "r") as f: + cfg = yaml.safe_load(f) + for agent in cfg["agents"]: + if agent["ref"] == "defender": + agent["agent_settings"]["flatten_obs"] = True + + # There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step + # function to save the action mask and the agent's chosen action to a local variable. + old_step_method = PrimaiteRayEnv.step + action_num_history = [] + mask_history = [] + + def cache_step(self, action: int): + action_num_history.append(action) + obs, *_ = old_step_method(self, action) + action_mask = obs["action_mask"] + mask_history.append(action_mask) + return obs, *_ + + monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs)) + + # Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule + config = ( + PPOConfig() + .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) + .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask") + .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule)) + .env_runners(num_env_runners=0) + .training(train_batch_size=128) + ) + algo = config.build() + algo.train() + + assert len(action_num_history) == len(mask_history) > 0 + # Make sure the masks had at least some False entries, if it was all True then the mask was disabled + assert any([not all(x) for x in mask_history]) + # When the agent takes action N from its action map, we need to have a look at the action mask and make sure that + # the N-th action was valid. + # The first step uses the action mask provided by the reset method, so we are only checking from the second step + # onward, that's why we need to use mask_history[:-1] and action_num_history[1:] + assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:])) + monkeypatch.undo() + + +def test_ray_multi_agent_action_masking(monkeypatch): + """Check that Ray agents never take invalid actions when using MARL.""" + with open(MARL_PATH, "r") as f: + cfg = yaml.safe_load(f) + + old_step_method = PrimaiteRayMARLEnv.step + action_num_history = {"defender_1": [], "defender_2": []} + mask_history = {"defender_1": [], "defender_2": []} + + def cache_step(self, actions: Dict[str, int]): + for agent_name, action in actions.items(): + action_num_history[agent_name].append(action) + obs, *_ = old_step_method(self, actions) + for ( + agent_name, + o, + ) in obs.items(): + mask_history[agent_name].append(o["action_mask"]) + return obs, *_ + + monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs)) + + config = ( + PPOConfig() + .multi_agent( + policies={ + "defender_1", + "defender_2", + }, # These names are the same as the agents defined in the example config. + policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id, + ) + .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) + .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask") + .rl_module( + rl_module_spec=MultiAgentRLModuleSpec( + module_specs={ + "defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule), + "defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule), + } + ) + ) + .env_runners(num_env_runners=0) + .training(train_batch_size=128) + ) + algo = config.build() + algo.train() + + for agent_name in ["defender_1", "defender_2"]: + act_hist = action_num_history[agent_name] + mask_hist = mask_history[agent_name] + assert len(act_hist) == len(mask_hist) > 0 + assert any([not all(x) for x in mask_hist]) + assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:])) + monkeypatch.undo() diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index b7acc8a8..0c9ec6f0 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -99,7 +99,7 @@ class TestConfigureDatabaseAction: game.step() assert db_client.server_ip_address == old_ip - assert db_client.server_password is "admin123" + assert db_client.server_password == "admin123" class TestConfigureRansomwareScriptAction: