2623 - finish testing action mask
This commit is contained in:
@@ -92,6 +92,7 @@ class Service(IOSoftware):
|
||||
_is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING)
|
||||
_is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED)
|
||||
_is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED)
|
||||
_is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -131,7 +132,12 @@ class Service(IOSoftware):
|
||||
),
|
||||
)
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
rm.add_request(
|
||||
"enable",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"fix",
|
||||
RequestType(
|
||||
|
||||
@@ -243,25 +243,25 @@ agents:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
@@ -272,22 +272,22 @@ agents:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
@@ -518,11 +518,22 @@ agents:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
@@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
ray.init(local_mode=True)
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@ from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
init(local_mode=True)
|
||||
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
@@ -16,8 +16,6 @@ def test_rllib_multi_agent_compatibility():
|
||||
with open(MULTI_AGENT_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
ray.init()
|
||||
|
||||
config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)
|
||||
@@ -39,4 +37,3 @@ def test_rllib_multi_agent_compatibility():
|
||||
),
|
||||
param_space=config,
|
||||
).fit()
|
||||
ray.shutdown()
|
||||
|
||||
@@ -20,9 +20,6 @@ def test_rllib_single_agent_compatibility():
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
@@ -41,4 +38,3 @@ def test_rllib_single_agent_compatibility():
|
||||
assert save_file.exists()
|
||||
|
||||
save_file.unlink() # clean up
|
||||
ray.shutdown()
|
||||
|
||||
@@ -65,25 +65,25 @@ class TestPrimaiteEnvironment:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
|
||||
assert set(env._agent_ids) == {"defender1", "defender2"}
|
||||
assert set(env._agent_ids) == {"defender_1", "defender_2"}
|
||||
|
||||
assert len(env.agents) == 2
|
||||
defender1 = env.agents["defender1"]
|
||||
defender2 = env.agents["defender2"]
|
||||
assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54
|
||||
assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38
|
||||
defender_1 = env.agents["defender_1"]
|
||||
defender_2 = env.agents["defender_2"]
|
||||
assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 74
|
||||
assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 74
|
||||
|
||||
# ensure we can run all valid actions without error
|
||||
for act_1 in range(num_actions_1):
|
||||
env.step({"defender1": act_1, "defender2": 0})
|
||||
env.step({"defender_1": act_1, "defender_2": 0})
|
||||
for act_2 in range(num_actions_2):
|
||||
env.step({"defender1": 0, "defender2": act_2})
|
||||
env.step({"defender_1": 0, "defender_2": act_2})
|
||||
|
||||
# ensure we get error when taking an invalid action
|
||||
with pytest.raises(KeyError):
|
||||
env.step({"defender1": num_actions_1, "defender2": 0})
|
||||
env.step({"defender_1": num_actions_1, "defender_2": 0})
|
||||
with pytest.raises(KeyError):
|
||||
env.step({"defender1": 0, "defender2": num_actions_2})
|
||||
env.step({"defender_1": 0, "defender_2": num_actions_2})
|
||||
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
"""Make sure we throw an error when the config is bad."""
|
||||
|
||||
161
tests/integration_tests/game_layer/test_action_mask.py
Normal file
161
tests/integration_tests/game_layer/test_action_mask.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
|
||||
|
||||
def test_mask_contents_correct():
|
||||
env = PrimaiteGymEnv(CFG_PATH)
|
||||
game = env.game
|
||||
sim = game.simulation
|
||||
net = sim.network
|
||||
mask = game.action_mask("defender")
|
||||
agent = env.agent
|
||||
node_list = agent.action_manager.node_names
|
||||
action_map = agent.action_manager.action_map
|
||||
|
||||
# CHECK NIC ENABLE/DISABLE ACTIONS
|
||||
for action_num, action in action_map.items():
|
||||
mask = game.action_mask("defender")
|
||||
act_type, act_params = action
|
||||
|
||||
if act_type == "NODE_NIC_ENABLE":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
|
||||
assert nic_obj.enabled
|
||||
assert not mask[action_num]
|
||||
nic_obj.disable()
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
nic_obj.enable()
|
||||
|
||||
if act_type == "NODE_NIC_DISABLE":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
|
||||
assert nic_obj.enabled
|
||||
assert mask[action_num]
|
||||
nic_obj.disable()
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
nic_obj.enable()
|
||||
|
||||
if act_type == "ROUTER_ACL_ADDRULE":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "ROUTER_ACL_REMOVERULE":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "NODE_RESET":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "NODE_SHUTDOWN":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "NODE_OS_SCAN":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "NODE_STARTUP":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "DONOTHING":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "NODE_SERVICE_DISABLE":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]:
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.DISABLED
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type == "NODE_SERVICE_RESUME":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.PAUSED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type == "NODE_SERVICE_START":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.STOPPED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type == "NODE_SERVICE_ENABLE":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.DISABLED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]:
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"])
|
||||
file_name = agent.action_manager.get_file_name_by_idx(
|
||||
act_params["node_id"], act_params["folder_id"], act_params["file_id"]
|
||||
)
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True)
|
||||
assert not file_obj.deleted
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.DISABLED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
Reference in New Issue
Block a user