Merged PR 454: Action masking support

## Summary
* Add action masking to environments
* add notebook demonstrating action masking

## Test process
* E2E test for SB3, Ray SARL, and Ray MARL
* integration test to check if the contents of the action mask change accordingly when statuses of components change

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [x] written **tests** for any new functionality added with this PR
- [x] updated the **documentation** if this PR changes or adds functionality
- [x] written/updated **design docs** if this PR implements new functionality
  - https://dev.azure.com/ma-dev-uk/PrimAITE/_wiki/wikis/PrimAITE.wiki/703/Action-Masking
- [ ] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Related work items: #2623
This commit is contained in:
Marek Wolan
2024-07-12 08:14:47 +00:00
25 changed files with 1525 additions and 350 deletions

View File

@@ -123,6 +123,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/environment
source/customising_scenarios
source/varying_config_files
source/action_masking
.. toctree::
:caption: Notebooks:

View File

@@ -0,0 +1,80 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
Action Masking
**************
The PrimAITE simulation is able to provide action masks in the environment output. These action masks let the agents know
about which actions are invalid based on the current environment state. For instance, it's not possible to install
software on a node that is turned off. Therefore, if an agent has a NODE_SOFTWARE_INSTALL in it's action map for that node,
the action mask will show `0` in the corresponding entry.
Configuration
=============
Action masking is supported for agents that use the `ProxyAgent` class (the class used for connecting to RL algorithms).
In order to use action masking, set the agent_settings.action_masking parameter to True in the config file.
Masking Logic
=============
The following logic is applied:
* **DONOTHING** : Always possible
* **NODE_HOST_SERVICE_SCAN** : Node is on. Service is running.
* **NODE_HOST_SERVICE_STOP** : Node is on. Service is running.
* **NODE_HOST_SERVICE_START** : Node is on. Service is stopped.
* **NODE_HOST_SERVICE_PAUSE** : Node is on. Service is running.
* **NODE_HOST_SERVICE_RESUME** : Node is on. Service is paused.
* **NODE_HOST_SERVICE_RESTART** : Node is on. Service is running.
* **NODE_HOST_SERVICE_DISABLE** : Node is on.
* **NODE_HOST_SERVICE_ENABLE** : Node is on. Service is disabled.
* **NODE_HOST_SERVICE_FIX** : Node is on. Service is running.
* **NODE_HOST_APPLICATION_EXECUTE** : Node is on.
* **NODE_HOST_APPLICATION_SCAN** : Node is on. Application is running.
* **NODE_HOST_APPLICATION_CLOSE** : Node is on. Application is running.
* **NODE_HOST_APPLICATION_FIX** : Node is on. Application is running.
* **NODE_HOST_APPLICATION_INSTALL** : Node is on.
* **NODE_HOST_APPLICATION_REMOVE** : Node is on.
* **NODE_HOST_FILE_SCAN** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_CREATE** : Node is on.
* **NODE_HOST_FILE_CHECKHASH** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_DELETE** : Node is on. File exists.
* **NODE_HOST_FILE_REPAIR** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_RESTORE** : Node is on. File exists. File is deleted.
* **NODE_HOST_FILE_CORRUPT** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_ACCESS** : Node is on. File exists. File not deleted.
* **NODE_HOST_FOLDER_CREATE** : Node is on.
* **NODE_HOST_FOLDER_SCAN** : Node is on. Folder exists. Folder not deleted.
* **NODE_HOST_FOLDER_CHECKHASH** : Node is on. Folder exists. Folder not deleted.
* **NODE_HOST_FOLDER_REPAIR** : Node is on. Folder exists. Folder not deleted.
* **NODE_HOST_FOLDER_RESTORE** : Node is on. Folder exists. Folder is deleted.
* **NODE_HOST_OS_SCAN** : Node is on.
* **NODE_HOST_NIC_ENABLE** : NIC is disabled. Node is on.
* **NODE_HOST_NIC_DISABLE** : NIC is enabled. Node is on.
* **NODE_HOST_SHUTDOWN** : Node is on.
* **NODE_HOST_STARTUP** : Node is off.
* **NODE_HOST_RESET** : Node is on.
* **NODE_HOST_NMAP_PING_SCAN** : Node is on.
* **NODE_HOST_NMAP_PORT_SCAN** : Node is on.
* **NODE_HOST_NMAP_NETWORK_SERVICE_RECON** : Node is on.
* **NODE_ROUTER_PORT_ENABLE** : Router is on.
* **NODE_ROUTER_PORT_DISABLE** : Router is on.
* **NODE_ROUTER_ACL_ADDRULE** : Router is on.
* **NODE_ROUTER_ACL_REMOVERULE** : Router is on.
* **NODE_FIREWALL_PORT_ENABLE** : Firewall is on.
* **NODE_FIREWALL_PORT_DISABLE** : Firewall is on.
* **NODE_FIREWALL_ACL_ADDRULE** : Firewall is on.
* **NODE_FIREWALL_ACL_REMOVERULE** : Firewall is on.
Mechanism
=========
The environment iterates over the RL agent's ``action_map`` and generates the corresponding simulator request string.
It uses the ``RequestManager.check_valid()`` method to invoke the relevant ``RequestPermissionValidator`` without
actually running the request on the simulation.
Current Limitations
===================
Currently, action masking only considers whether the action as a whole is possible, it doesn't verify that the exact
parameter combination passed to the action make sense in the current context. For instance, if ACL rule 3 on router_1 is
already populated, the action for adding another rule at position 3 will be available regardless, as long as that router
is turned on. This will never block valid actions. It will just occasionally allow invalid actions.

View File

@@ -55,6 +55,7 @@ rl = [
"ray[rllib] >= 2.20.0, < 3",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"sb3-contrib==2.1.0",
]
dev = [
"build==0.10.0",

View File

@@ -741,6 +741,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -733,6 +733,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true
- ref: defender_2
team: BLUE
@@ -1316,6 +1317,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path:
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path
def data_manipulation_marl_config_path() -> Path:
"""
Get the path to the MARL example config.
:return: Path to yaml config file for the MARL scenario.
:rtype: Path
"""
path = _EXAMPLE_CFG / "data_manipulation_marl.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -49,7 +49,7 @@ class AbstractAction(ABC):
objects."""
@abstractmethod
def form_request(self) -> List[str]:
def form_request(self) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
@@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction):
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self, **kwargs) -> List[str]:
def form_request(self, **kwargs) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
@@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, service_id: int) -> List[str]:
def form_request(self, node_id: int, service_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
service_name = self.manager.get_service_name_by_idx(node_id, service_id)
@@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, application_id: int) -> List[str]:
def form_request(self, node_id: int, application_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
application_name = self.manager.get_application_name_by_idx(node_id, application_id)
@@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]:
def form_request(
self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False
) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
@@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None:
@@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb: str = "delete"
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "access"
def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
@@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int) -> List[str]:
def form_request(self, node_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return ["network", "node", node_name, self.verb]
@@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
def form_request(self, target_router: str, position: int) -> List[str]:
def form_request(self, target_router: str, position: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", target_router, "acl", "remove_rule", position]
@@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, nic_id: int) -> List[str]:
def form_request(self, node_id: int, nic_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id)
@@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
@@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
@@ -1315,7 +1317,7 @@ class ActionManager:
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_obj = self.actions[action_identifier]
return act_obj.form_request(**action_options)

View File

@@ -70,6 +70,8 @@ class AgentSettings(BaseModel):
"Configuration for when an agent begins performing it's actions"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
@@ -207,6 +209,7 @@ class ProxyAgent(AbstractAgent):
)
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -3,6 +3,7 @@
from ipaddress import IPv4Address
from typing import Dict, List, Optional
import numpy as np
from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
@@ -202,6 +203,23 @@ class PrimaiteGame:
return True
return False
def action_mask(self, agent_name: str) -> np.ndarray:
"""
Return the action mask for the agent.
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
performed during this step.
:return: Action mask
:rtype: List[bool]
"""
agent = self.agents[agent_name]
mask = [True] * len(agent.action_manager.action_map)
for i, action in agent.action_manager.action_map.items():
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.simulation._request_manager.check_valid(request, {})
return np.asarray(mask, dtype=np.int8)
def close(self) -> None:
"""Close the game, this will close the simulation."""
return NotImplemented

View File

@@ -0,0 +1,218 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Action Masking\n",
"\n",
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = PrimaiteGymEnv(data_manipulation_config_path())\n",
"env.action_masking = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible. Demonstrated here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n",
"mask = env.action_masks()\n",
"actions = env.agent.action_manager.action_map\n",
"max_str_len = 70\n",
"for act,mask in zip(actions.items(), mask):\n",
" act_num, act_data = act\n",
" act_type, act_params = act_data\n",
" act_params = s if len(s:=str(act_params))<max_str_len else f\"{s[:max_str_len-3]}...\"\n",
" act_table.add_row((act_num, act_type, act_params, mask))\n",
"print(act_table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking for Stable Baselines3 agents\n",
"SB3 agents automatically use the action_masks method during the training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sb3_contrib import MaskablePPO\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = MaskablePPO(\"MlpPolicy\", env, gamma=0.4, seed=32)\n",
"model.learn(1024)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking for Ray RLLib agents\n",
"Ray uses a different API to obtain action masks, but this is handled by the PrimaiteRayEnv and PrimaiteRayMarlEnv classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n",
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"for agent in cfg['agents']:\n",
" if agent[\"ref\"] == \"defender\":\n",
" agent['agent_settings']['flatten_obs'] = True\n",
"env_config = cfg\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking with MARL in Ray RLLib\n",
"Each agent has their own action mask, this is useful if the agents have different action spaces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"from primaite.config.load import data_manipulation_marl_config_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_marl_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"env_config = cfg\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .multi_agent(\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
" )\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
" \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" }))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -4,6 +4,7 @@ from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
import numpy as np
from gymnasium.core import ActType, ObsType
from primaite import getLogger
@@ -41,6 +42,21 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
performed during this step.
:return: Action mask
:rtype: List[bool]
"""
if not self.agent.action_masking:
return np.asarray([True] * len(self.agent.action_manager.action_map))
else:
return self.game.action_mask(self._agent_name)
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""

View File

@@ -3,6 +3,7 @@ import json
from typing import Dict, SupportsFloat, Tuple
import gymnasium
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -38,15 +39,19 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
self.observation_space = spaces.Dict(
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
)
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
if agent.action_masking:
self.observation_space[agent_name] = spaces.Dict(
{
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
"observations": self.observation_space[agent_name],
}
)
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
self._obs_space_in_preferred_format = True
self._action_space_in_preferred_format = True
super().__init__()
@@ -131,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
all_obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
if agent.action_masking:
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
else:
all_obs[agent_name] = obs
return all_obs
def close(self):
"""Close the simulation."""
@@ -158,15 +167,30 @@ class PrimaiteRayEnv(gymnasium.Env):
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
if self.env.agent.action_masking:
self.observation_space = spaces.Dict(
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
)
else:
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
return new_obs, *_
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
# if action masking is enabled, intercept the step method and add action mask to observation
if self.env.agent.action_masking:
obs, *_ = self.env.step(action)
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
return new_obs, *_
else:
return self.env.step(action)
def close(self):
"""Close the simulation."""

View File

@@ -3,9 +3,10 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Union
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
@@ -164,18 +165,51 @@ class RequestManager(BaseModel):
self.request_types.pop(name)
def get_request_types_recursively(self) -> List[List[str]]:
"""Recursively generate request tree for this component."""
def get_request_types_recursively(self) -> List[RequestFormat]:
"""
Recursively generate request tree for this component.
:param parent_valid: Whether this sub-request's parent request was valid. This value should not be specified by
users, it is used by the recursive call.
:type parent_valid: bool
:returns: A list of tuples where the first tuple element is the request string and the second is whether that
request is currently possible to execute.
:rtype: List[Tuple[RequestFormat, bool]]
"""
requests = []
for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively()
sub_requests = [[req_name] + a for a in sub_requests]
sub_requests = req.func.get_request_types_recursively() # recurse
sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf
requests.extend(sub_requests)
else:
requests.append([req_name])
else: # leaf node found
requests.append(req_name)
return requests
def show(self) -> None:
"""Display all currently available requests and whether they are valid."""
table = PrettyTable(["request"])
table.align = "l"
table.add_rows(self.get_request_types_recursively())
print(table)
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
"""Check if this request would be valid in the current state of the simulation without invoking it."""
request_key = request[0]
request_options = request[1:]
if request_key not in self.request_types:
return False
request_type = self.request_types[request_key]
# recurse if we are not at a leaf node
if isinstance(request_type.func, RequestManager):
return request_type.func.check_valid(request_options, context)
return request_type.validator(request_options, context)
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""

View File

@@ -52,6 +52,8 @@ class GroupMembershipValidator(RequestPermissionValidator):
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
# if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false
if not context:
return False
requestor_groups: List[str] = context["request_source"]["groups"]
for allowed_group in self.allowed_groups:
if allowed_group.name in requestor_groups:

View File

@@ -92,6 +92,7 @@ class Service(IOSoftware):
_is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING)
_is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED)
_is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED)
_is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED)
rm = super()._init_request_manager()
rm.add_request(
@@ -131,7 +132,12 @@ class Service(IOSoftware):
),
)
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
rm.add_request(
"enable",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled
),
)
rm.add_request(
"fix",
RequestType(

File diff suppressed because it is too large Load Diff

View File

@@ -243,25 +243,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_FIX"
@@ -272,22 +272,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -518,11 +518,22 @@ agents:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
@@ -557,6 +568,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true
@@ -634,6 +646,8 @@ simulation:
dns_server: 192.168.1.10
services:
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: server
hostname: backup_server

View File

@@ -2,6 +2,7 @@
from typing import Any, Dict, Tuple
import pytest
import ray
import yaml
from primaite import getLogger, PRIMAITE_PATHS
@@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
ray.init(local_mode=True)
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,158 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import importlib
from typing import Dict
import yaml
from ray import air, init, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule
from sb3_contrib import MaskablePPO
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
from tests import TEST_ASSETS_ROOT
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def test_sb3_action_masking(monkeypatch):
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action
# mask function here to save the output of the action mask method and pass through the result back to the agent.
old_action_mask_method = PrimaiteGame.action_mask
mask_history = []
def cache_action_mask(obj, agent_name):
mask = old_action_mask_method(obj, agent_name)
mask_history.append(mask)
return mask
# Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which
# action map action integer that was, therefore we cache it by using monkeypatch
action_num_history = []
def cache_step(env, action: int):
action_num_history.append(action)
return PrimaiteGymEnv.step(env, action)
monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask)
env = PrimaiteGymEnv(CFG_PATH)
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
model.learn(512)
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
assert any([not all(x) for x in mask_history])
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
# the N-th entry was True, meaning that it was a valid action at that step.
# This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this
# happens for all steps i in the episode
assert all(mask_history[i][a] for i, a in enumerate(action_num_history))
monkeypatch.undo()
def test_ray_single_agent_action_masking(monkeypatch):
"""Check that a Ray agent uses the action mask and never chooses invalid actions."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
for agent in cfg["agents"]:
if agent["ref"] == "defender":
agent["agent_settings"]["flatten_obs"] = True
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step
# function to save the action mask and the agent's chosen action to a local variable.
old_step_method = PrimaiteRayEnv.step
action_num_history = []
mask_history = []
def cache_step(self, action: int):
action_num_history.append(action)
obs, *_ = old_step_method(self, action)
action_mask = obs["action_mask"]
mask_history.append(action_mask)
return obs, *_
monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
# Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule
config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
.environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask")
.rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule))
.env_runners(num_env_runners=0)
.training(train_batch_size=128)
)
algo = config.build()
algo.train()
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
assert any([not all(x) for x in mask_history])
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
# the N-th action was valid.
# The first step uses the action mask provided by the reset method, so we are only checking from the second step
# onward, that's why we need to use mask_history[:-1] and action_num_history[1:]
assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:]))
monkeypatch.undo()
def test_ray_multi_agent_action_masking(monkeypatch):
"""Check that Ray agents never take invalid actions when using MARL."""
with open(MARL_PATH, "r") as f:
cfg = yaml.safe_load(f)
old_step_method = PrimaiteRayMARLEnv.step
action_num_history = {"defender_1": [], "defender_2": []}
mask_history = {"defender_1": [], "defender_2": []}
def cache_step(self, actions: Dict[str, int]):
for agent_name, action in actions.items():
action_num_history[agent_name].append(action)
obs, *_ = old_step_method(self, actions)
for (
agent_name,
o,
) in obs.items():
mask_history[agent_name].append(o["action_mask"])
return obs, *_
monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
config = (
PPOConfig()
.multi_agent(
policies={
"defender_1",
"defender_2",
}, # These names are the same as the agents defined in the example config.
policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
)
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
.environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask")
.rl_module(
rl_module_spec=MultiAgentRLModuleSpec(
module_specs={
"defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
"defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
}
)
)
.env_runners(num_env_runners=0)
.training(train_batch_size=128)
)
algo = config.build()
algo.train()
for agent_name in ["defender_1", "defender_2"]:
act_hist = action_num_history[agent_name]
mask_hist = mask_history[agent_name]
assert len(act_hist) == len(mask_hist) > 0
assert any([not all(x) for x in mask_hist])
assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:]))
monkeypatch.undo()

View File

@@ -16,8 +16,6 @@ def test_rllib_multi_agent_compatibility():
with open(MULTI_AGENT_PATH, "r") as f:
cfg = yaml.safe_load(f)
ray.init()
config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)
@@ -39,4 +37,3 @@ def test_rllib_multi_agent_compatibility():
),
param_space=config,
).fit()
ray.shutdown()

View File

@@ -20,9 +20,6 @@ def test_rllib_single_agent_compatibility():
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = {
"env": PrimaiteRayEnv,
@@ -41,4 +38,3 @@ def test_rllib_single_agent_compatibility():
assert save_file.exists()
save_file.unlink() # clean up
ray.shutdown()

View File

@@ -65,25 +65,25 @@ class TestPrimaiteEnvironment:
cfg = yaml.safe_load(f)
env = PrimaiteRayMARLEnv(env_config=cfg)
assert set(env._agent_ids) == {"defender1", "defender2"}
assert set(env._agent_ids) == {"defender_1", "defender_2"}
assert len(env.agents) == 2
defender1 = env.agents["defender1"]
defender2 = env.agents["defender2"]
assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54
assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38
defender_1 = env.agents["defender_1"]
defender_2 = env.agents["defender_2"]
assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 78
assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 78
# ensure we can run all valid actions without error
for act_1 in range(num_actions_1):
env.step({"defender1": act_1, "defender2": 0})
env.step({"defender_1": act_1, "defender_2": 0})
for act_2 in range(num_actions_2):
env.step({"defender1": 0, "defender2": act_2})
env.step({"defender_1": 0, "defender_2": act_2})
# ensure we get error when taking an invalid action
with pytest.raises(KeyError):
env.step({"defender1": num_actions_1, "defender2": 0})
env.step({"defender_1": num_actions_1, "defender_2": 0})
with pytest.raises(KeyError):
env.step({"defender1": 0, "defender2": num_actions_2})
env.step({"defender_1": 0, "defender_2": num_actions_2})
def test_error_thrown_on_bad_configuration(self):
"""Make sure we throw an error when the config is bad."""

View File

@@ -99,7 +99,7 @@ class TestConfigureDatabaseAction:
game.step()
assert db_client.server_ip_address == old_ip
assert db_client.server_password is "admin123"
assert db_client.server_password == "admin123"
class TestConfigureRansomwareScriptAction:

View File

@@ -0,0 +1,161 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.service import ServiceOperatingState
from tests.conftest import TEST_ASSETS_ROOT
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
def test_mask_contents_correct():
env = PrimaiteGymEnv(CFG_PATH)
game = env.game
sim = game.simulation
net = sim.network
mask = game.action_mask("defender")
agent = env.agent
node_list = agent.action_manager.node_names
action_map = agent.action_manager.action_map
# CHECK NIC ENABLE/DISABLE ACTIONS
for action_num, action in action_map.items():
mask = game.action_mask("defender")
act_type, act_params = action
if act_type == "NODE_NIC_ENABLE":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
assert nic_obj.enabled
assert not mask[action_num]
nic_obj.disable()
mask = game.action_mask("defender")
assert mask[action_num]
nic_obj.enable()
if act_type == "NODE_NIC_DISABLE":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
assert nic_obj.enabled
assert mask[action_num]
nic_obj.disable()
mask = game.action_mask("defender")
assert not mask[action_num]
nic_obj.enable()
if act_type == "ROUTER_ACL_ADDRULE":
assert mask[action_num]
if act_type == "ROUTER_ACL_REMOVERULE":
assert mask[action_num]
if act_type == "NODE_RESET":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_SHUTDOWN":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_OS_SCAN":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_STARTUP":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "DONOTHING":
assert mask[action_num]
if act_type == "NODE_SERVICE_DISABLE":
assert mask[action_num]
if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]:
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.DISABLED
mask = game.action_mask("defender")
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_RESUME":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.PAUSED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_START":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.STOPPED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_ENABLE":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.DISABLED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]:
node_name = node_list[act_params["node_id"]]
folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"])
file_name = agent.action_manager.get_file_name_by_idx(
act_params["node_id"], act_params["folder_id"], act_params["file_id"]
)
node_obj = net.get_node_by_hostname(node_name)
file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True)
assert not file_obj.deleted
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.DISABLED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING