From 0c58c3969a901a8d508b2a429080905a311811ee Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 10 Jul 2024 13:46:30 +0100 Subject: [PATCH] 2623 - finish testing action mask --- .../simulator/system/services/service.py | 8 +- .../assets/configs/test_primaite_session.yaml | 27 ++- tests/conftest.py | 2 + .../test_agents_use_action_masks.py | 2 - .../test_rllib_multi_agent_environment.py | 3 - .../test_rllib_single_agent_environment.py | 4 - .../e2e_integration_tests/test_environment.py | 18 +- .../game_layer/test_action_mask.py | 161 ++++++++++++++++++ 8 files changed, 198 insertions(+), 27 deletions(-) create mode 100644 tests/integration_tests/game_layer/test_action_mask.py diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 8167a8a9..5adea6e7 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -92,6 +92,7 @@ class Service(IOSoftware): _is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING) _is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED) _is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED) + _is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED) rm = super()._init_request_manager() rm.add_request( @@ -131,7 +132,12 @@ class Service(IOSoftware): ), ) rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) - rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request( + "enable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled + ), + ) rm.add_request( "fix", RequestType( diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 7c894ba0..c435fe44 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -243,25 +243,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_FIX" @@ -272,22 +272,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -518,11 +518,22 @@ agents: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 diff --git a/tests/conftest.py b/tests/conftest.py index e36a2460..adfa7724 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Tuple import pytest +import ray import yaml from primaite import getLogger, PRIMAITE_PATHS @@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT +ray.init(local_mode=True) ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index 3efda71a..a299b913 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -15,8 +15,6 @@ from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv from tests import TEST_ASSETS_ROOT -init(local_mode=True) - CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 96ec799c..e015c33c 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -16,8 +16,6 @@ def test_rllib_multi_agent_compatibility(): with open(MULTI_AGENT_PATH, "r") as f: cfg = yaml.safe_load(f) - ray.init() - config = ( PPOConfig() .environment(env=PrimaiteRayMARLEnv, env_config=cfg) @@ -39,4 +37,3 @@ def test_rllib_multi_agent_compatibility(): ), param_space=config, ).fit() - ray.shutdown() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index d6cacfd2..a02a078c 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -20,9 +20,6 @@ def test_rllib_single_agent_compatibility(): game = PrimaiteGame.from_config(cfg) - ray.shutdown() - ray.init() - env_config = {"game": game} config = { "env": PrimaiteRayEnv, @@ -41,4 +38,3 @@ def test_rllib_single_agent_compatibility(): assert save_file.exists() save_file.unlink() # clean up - ray.shutdown() diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index c8238aba..253bd396 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -65,25 +65,25 @@ class TestPrimaiteEnvironment: cfg = yaml.safe_load(f) env = PrimaiteRayMARLEnv(env_config=cfg) - assert set(env._agent_ids) == {"defender1", "defender2"} + assert set(env._agent_ids) == {"defender_1", "defender_2"} assert len(env.agents) == 2 - defender1 = env.agents["defender1"] - defender2 = env.agents["defender2"] - assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54 - assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38 + defender_1 = env.agents["defender_1"] + defender_2 = env.agents["defender_2"] + assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 74 + assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 74 # ensure we can run all valid actions without error for act_1 in range(num_actions_1): - env.step({"defender1": act_1, "defender2": 0}) + env.step({"defender_1": act_1, "defender_2": 0}) for act_2 in range(num_actions_2): - env.step({"defender1": 0, "defender2": act_2}) + env.step({"defender_1": 0, "defender_2": act_2}) # ensure we get error when taking an invalid action with pytest.raises(KeyError): - env.step({"defender1": num_actions_1, "defender2": 0}) + env.step({"defender_1": num_actions_1, "defender_2": 0}) with pytest.raises(KeyError): - env.step({"defender1": 0, "defender2": num_actions_2}) + env.step({"defender_1": 0, "defender_2": num_actions_2}) def test_error_thrown_on_bad_configuration(self): """Make sure we throw an error when the config is bad.""" diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py new file mode 100644 index 00000000..64464724 --- /dev/null +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -0,0 +1,161 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.system.services.service import ServiceOperatingState +from tests.conftest import TEST_ASSETS_ROOT + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" + + +def test_mask_contents_correct(): + env = PrimaiteGymEnv(CFG_PATH) + game = env.game + sim = game.simulation + net = sim.network + mask = game.action_mask("defender") + agent = env.agent + node_list = agent.action_manager.node_names + action_map = agent.action_manager.action_map + + # CHECK NIC ENABLE/DISABLE ACTIONS + for action_num, action in action_map.items(): + mask = game.action_mask("defender") + act_type, act_params = action + + if act_type == "NODE_NIC_ENABLE": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + assert nic_obj.enabled + assert not mask[action_num] + nic_obj.disable() + mask = game.action_mask("defender") + assert mask[action_num] + nic_obj.enable() + + if act_type == "NODE_NIC_DISABLE": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + assert nic_obj.enabled + assert mask[action_num] + nic_obj.disable() + mask = game.action_mask("defender") + assert not mask[action_num] + nic_obj.enable() + + if act_type == "ROUTER_ACL_ADDRULE": + assert mask[action_num] + + if act_type == "ROUTER_ACL_REMOVERULE": + assert mask[action_num] + + if act_type == "NODE_RESET": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_SHUTDOWN": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_OS_SCAN": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_STARTUP": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "DONOTHING": + assert mask[action_num] + + if act_type == "NODE_SERVICE_DISABLE": + assert mask[action_num] + + if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]: + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_RESUME": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.PAUSED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_START": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.STOPPED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_ENABLE": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]: + node_name = node_list[act_params["node_id"]] + folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"]) + file_name = agent.action_manager.get_file_name_by_idx( + act_params["node_id"], act_params["folder_id"], act_params["file_id"] + ) + node_obj = net.get_node_by_hostname(node_name) + file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True) + assert not file_obj.deleted + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING