Merged PR 454: Action masking support
## Summary * Add action masking to environments * add notebook demonstrating action masking ## Test process * E2E test for SB3, Ray SARL, and Ray MARL * integration test to check if the contents of the action mask change accordingly when statuses of components change ## Checklist - [x] PR is linked to a **work item** - [x] **acceptance criteria** of linked ticket are met - [x] performed **self-review** of the code - [x] written **tests** for any new functionality added with this PR - [x] updated the **documentation** if this PR changes or adds functionality - [x] written/updated **design docs** if this PR implements new functionality - https://dev.azure.com/ma-dev-uk/PrimAITE/_wiki/wikis/PrimAITE.wiki/703/Action-Masking - [ ] updated the **change log** - [x] ran **pre-commit** checks for code style - [x] attended to any **TO-DOs** left in the code Related work items: #2623
This commit is contained in:
@@ -123,6 +123,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/environment
|
||||
source/customising_scenarios
|
||||
source/varying_config_files
|
||||
source/action_masking
|
||||
|
||||
.. toctree::
|
||||
:caption: Notebooks:
|
||||
|
||||
80
docs/source/action_masking.rst
Normal file
80
docs/source/action_masking.rst
Normal file
@@ -0,0 +1,80 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Action Masking
|
||||
**************
|
||||
The PrimAITE simulation is able to provide action masks in the environment output. These action masks let the agents know
|
||||
about which actions are invalid based on the current environment state. For instance, it's not possible to install
|
||||
software on a node that is turned off. Therefore, if an agent has a NODE_SOFTWARE_INSTALL in it's action map for that node,
|
||||
the action mask will show `0` in the corresponding entry.
|
||||
|
||||
Configuration
|
||||
=============
|
||||
Action masking is supported for agents that use the `ProxyAgent` class (the class used for connecting to RL algorithms).
|
||||
In order to use action masking, set the agent_settings.action_masking parameter to True in the config file.
|
||||
|
||||
Masking Logic
|
||||
=============
|
||||
The following logic is applied:
|
||||
|
||||
* **DONOTHING** : Always possible
|
||||
* **NODE_HOST_SERVICE_SCAN** : Node is on. Service is running.
|
||||
* **NODE_HOST_SERVICE_STOP** : Node is on. Service is running.
|
||||
* **NODE_HOST_SERVICE_START** : Node is on. Service is stopped.
|
||||
* **NODE_HOST_SERVICE_PAUSE** : Node is on. Service is running.
|
||||
* **NODE_HOST_SERVICE_RESUME** : Node is on. Service is paused.
|
||||
* **NODE_HOST_SERVICE_RESTART** : Node is on. Service is running.
|
||||
* **NODE_HOST_SERVICE_DISABLE** : Node is on.
|
||||
* **NODE_HOST_SERVICE_ENABLE** : Node is on. Service is disabled.
|
||||
* **NODE_HOST_SERVICE_FIX** : Node is on. Service is running.
|
||||
* **NODE_HOST_APPLICATION_EXECUTE** : Node is on.
|
||||
* **NODE_HOST_APPLICATION_SCAN** : Node is on. Application is running.
|
||||
* **NODE_HOST_APPLICATION_CLOSE** : Node is on. Application is running.
|
||||
* **NODE_HOST_APPLICATION_FIX** : Node is on. Application is running.
|
||||
* **NODE_HOST_APPLICATION_INSTALL** : Node is on.
|
||||
* **NODE_HOST_APPLICATION_REMOVE** : Node is on.
|
||||
* **NODE_HOST_FILE_SCAN** : Node is on. File exists. File not deleted.
|
||||
* **NODE_HOST_FILE_CREATE** : Node is on.
|
||||
* **NODE_HOST_FILE_CHECKHASH** : Node is on. File exists. File not deleted.
|
||||
* **NODE_HOST_FILE_DELETE** : Node is on. File exists.
|
||||
* **NODE_HOST_FILE_REPAIR** : Node is on. File exists. File not deleted.
|
||||
* **NODE_HOST_FILE_RESTORE** : Node is on. File exists. File is deleted.
|
||||
* **NODE_HOST_FILE_CORRUPT** : Node is on. File exists. File not deleted.
|
||||
* **NODE_HOST_FILE_ACCESS** : Node is on. File exists. File not deleted.
|
||||
* **NODE_HOST_FOLDER_CREATE** : Node is on.
|
||||
* **NODE_HOST_FOLDER_SCAN** : Node is on. Folder exists. Folder not deleted.
|
||||
* **NODE_HOST_FOLDER_CHECKHASH** : Node is on. Folder exists. Folder not deleted.
|
||||
* **NODE_HOST_FOLDER_REPAIR** : Node is on. Folder exists. Folder not deleted.
|
||||
* **NODE_HOST_FOLDER_RESTORE** : Node is on. Folder exists. Folder is deleted.
|
||||
* **NODE_HOST_OS_SCAN** : Node is on.
|
||||
* **NODE_HOST_NIC_ENABLE** : NIC is disabled. Node is on.
|
||||
* **NODE_HOST_NIC_DISABLE** : NIC is enabled. Node is on.
|
||||
* **NODE_HOST_SHUTDOWN** : Node is on.
|
||||
* **NODE_HOST_STARTUP** : Node is off.
|
||||
* **NODE_HOST_RESET** : Node is on.
|
||||
* **NODE_HOST_NMAP_PING_SCAN** : Node is on.
|
||||
* **NODE_HOST_NMAP_PORT_SCAN** : Node is on.
|
||||
* **NODE_HOST_NMAP_NETWORK_SERVICE_RECON** : Node is on.
|
||||
* **NODE_ROUTER_PORT_ENABLE** : Router is on.
|
||||
* **NODE_ROUTER_PORT_DISABLE** : Router is on.
|
||||
* **NODE_ROUTER_ACL_ADDRULE** : Router is on.
|
||||
* **NODE_ROUTER_ACL_REMOVERULE** : Router is on.
|
||||
* **NODE_FIREWALL_PORT_ENABLE** : Firewall is on.
|
||||
* **NODE_FIREWALL_PORT_DISABLE** : Firewall is on.
|
||||
* **NODE_FIREWALL_ACL_ADDRULE** : Firewall is on.
|
||||
* **NODE_FIREWALL_ACL_REMOVERULE** : Firewall is on.
|
||||
|
||||
|
||||
Mechanism
|
||||
=========
|
||||
The environment iterates over the RL agent's ``action_map`` and generates the corresponding simulator request string.
|
||||
It uses the ``RequestManager.check_valid()`` method to invoke the relevant ``RequestPermissionValidator`` without
|
||||
actually running the request on the simulation.
|
||||
|
||||
Current Limitations
|
||||
===================
|
||||
Currently, action masking only considers whether the action as a whole is possible, it doesn't verify that the exact
|
||||
parameter combination passed to the action make sense in the current context. For instance, if ACL rule 3 on router_1 is
|
||||
already populated, the action for adding another rule at position 3 will be available regardless, as long as that router
|
||||
is turned on. This will never block valid actions. It will just occasionally allow invalid actions.
|
||||
@@ -55,6 +55,7 @@ rl = [
|
||||
"ray[rllib] >= 2.20.0, < 3",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"sb3-contrib==2.1.0",
|
||||
]
|
||||
dev = [
|
||||
"build==0.10.0",
|
||||
|
||||
@@ -741,6 +741,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -733,6 +733,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
- ref: defender_2
|
||||
team: BLUE
|
||||
@@ -1316,6 +1317,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path:
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
|
||||
|
||||
def data_manipulation_marl_config_path() -> Path:
|
||||
"""
|
||||
Get the path to the MARL example config.
|
||||
|
||||
:return: Path to yaml config file for the MARL scenario.
|
||||
:rtype: Path
|
||||
"""
|
||||
path = _EXAMPLE_CFG / "data_manipulation_marl.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
|
||||
@@ -49,7 +49,7 @@ class AbstractAction(ABC):
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> List[str]:
|
||||
def form_request(self) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return []
|
||||
|
||||
@@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction):
|
||||
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self, **kwargs) -> List[str]:
|
||||
def form_request(self, **kwargs) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
@@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, service_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, service_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
service_name = self.manager.get_service_name_by_idx(node_id, service_id)
|
||||
@@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, application_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, application_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
application_name = self.manager.get_application_name_by_idx(node_id, application_id)
|
||||
@@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
@@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
@@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
@@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "create"
|
||||
|
||||
def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]:
|
||||
def form_request(
|
||||
self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False
|
||||
) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None or folder_name is None or file_name is None:
|
||||
@@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "create"
|
||||
|
||||
def form_request(self, node_id: int, folder_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None or folder_name is None:
|
||||
@@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
@@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb: str = "delete"
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
@@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "access"
|
||||
|
||||
def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None or folder_name is None or file_name is None:
|
||||
@@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
return ["network", "node", node_name, self.verb]
|
||||
@@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
|
||||
def form_request(self, target_router: str, position: int) -> List[str]:
|
||||
def form_request(self, target_router: str, position: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", target_router, "acl", "remove_rule", position]
|
||||
|
||||
@@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, nic_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
|
||||
nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id)
|
||||
@@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
|
||||
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if target_nodename is None or port_id is None:
|
||||
return ["do_nothing"]
|
||||
@@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
|
||||
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if target_nodename is None or port_id is None:
|
||||
return ["do_nothing"]
|
||||
@@ -1315,7 +1317,7 @@ class ActionManager:
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_obj = self.actions[action_identifier]
|
||||
return act_obj.form_request(**action_options)
|
||||
|
||||
@@ -70,6 +70,8 @@ class AgentSettings(BaseModel):
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = False
|
||||
"Whether to return action masks at each step."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
@@ -207,6 +209,7 @@ class ProxyAgent(AbstractAgent):
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
@@ -202,6 +203,23 @@ class PrimaiteGame:
|
||||
return True
|
||||
return False
|
||||
|
||||
def action_mask(self, agent_name: str) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
|
||||
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
|
||||
performed during this step.
|
||||
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
agent = self.agents[agent_name]
|
||||
mask = [True] * len(agent.action_manager.action_map)
|
||||
for i, action in agent.action_manager.action_map.items():
|
||||
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
|
||||
mask[i] = self.simulation._request_manager.check_valid(request, {})
|
||||
return np.asarray(mask, dtype=np.int8)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
218
src/primaite/notebooks/Action-masking.ipynb
Normal file
218
src/primaite/notebooks/Action-masking.ipynb
Normal file
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Action Masking\n",
|
||||
"\n",
|
||||
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = PrimaiteGymEnv(data_manipulation_config_path())\n",
|
||||
"env.action_masking = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible. Demonstrated here:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n",
|
||||
"mask = env.action_masks()\n",
|
||||
"actions = env.agent.action_manager.action_map\n",
|
||||
"max_str_len = 70\n",
|
||||
"for act,mask in zip(actions.items(), mask):\n",
|
||||
" act_num, act_data = act\n",
|
||||
" act_type, act_params = act_data\n",
|
||||
" act_params = s if len(s:=str(act_params))<max_str_len else f\"{s[:max_str_len-3]}...\"\n",
|
||||
" act_table.add_row((act_num, act_type, act_params, mask))\n",
|
||||
"print(act_table)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Action masking for Stable Baselines3 agents\n",
|
||||
"SB3 agents automatically use the action_masks method during the training loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sb3_contrib import MaskablePPO\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = MaskablePPO(\"MlpPolicy\", env, gamma=0.4, seed=32)\n",
|
||||
"model.learn(1024)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Action masking for Ray RLLib agents\n",
|
||||
"Ray uses a different API to obtain action masks, but this is handled by the PrimaiteRayEnv and PrimaiteRayMarlEnv classes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"import yaml\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
|
||||
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent[\"ref\"] == \"defender\":\n",
|
||||
" agent['agent_settings']['flatten_obs'] = True\n",
|
||||
"env_config = cfg\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
|
||||
" .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n",
|
||||
"algo = config.build()\n",
|
||||
"for i in range(2):\n",
|
||||
" results = algo.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Action masking with MARL in Ray RLLib\n",
|
||||
"Each agent has their own action mask, this is useful if the agents have different action spaces."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"from primaite.config.load import data_manipulation_marl_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_marl_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"env_config = cfg\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .multi_agent(\n",
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
|
||||
" )\n",
|
||||
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
|
||||
" .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
|
||||
" \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
|
||||
" \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
|
||||
" }))\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n",
|
||||
"algo = config.build()\n",
|
||||
"for i in range(2):\n",
|
||||
" results = algo.train()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -4,6 +4,7 @@ from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -41,6 +42,21 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
|
||||
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
|
||||
performed during this step.
|
||||
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
if not self.agent.action_masking:
|
||||
return np.asarray([True] * len(self.agent.action_manager.action_map))
|
||||
else:
|
||||
return self.game.action_mask(self._agent_name)
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from typing import Dict, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
@@ -38,15 +39,19 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
self.observation_space = spaces.Dict(
|
||||
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
|
||||
)
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
if agent.action_masking:
|
||||
self.observation_space[agent_name] = spaces.Dict(
|
||||
{
|
||||
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
|
||||
"observations": self.observation_space[agent_name],
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
|
||||
self._obs_space_in_preferred_format = True
|
||||
self._action_space_in_preferred_format = True
|
||||
super().__init__()
|
||||
@@ -131,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
all_obs = {}
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if agent.action_masking:
|
||||
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
|
||||
else:
|
||||
all_obs[agent_name] = obs
|
||||
return all_obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
@@ -158,15 +167,30 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
if self.env.agent.action_masking:
|
||||
self.observation_space = spaces.Dict(
|
||||
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
|
||||
)
|
||||
else:
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
|
||||
return new_obs, *_
|
||||
else:
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -164,18 +165,51 @@ class RequestManager(BaseModel):
|
||||
|
||||
self.request_types.pop(name)
|
||||
|
||||
def get_request_types_recursively(self) -> List[List[str]]:
|
||||
"""Recursively generate request tree for this component."""
|
||||
def get_request_types_recursively(self) -> List[RequestFormat]:
|
||||
"""
|
||||
Recursively generate request tree for this component.
|
||||
|
||||
:param parent_valid: Whether this sub-request's parent request was valid. This value should not be specified by
|
||||
users, it is used by the recursive call.
|
||||
:type parent_valid: bool
|
||||
:returns: A list of tuples where the first tuple element is the request string and the second is whether that
|
||||
request is currently possible to execute.
|
||||
:rtype: List[Tuple[RequestFormat, bool]]
|
||||
"""
|
||||
requests = []
|
||||
for req_name, req in self.request_types.items():
|
||||
if isinstance(req.func, RequestManager):
|
||||
sub_requests = req.func.get_request_types_recursively()
|
||||
sub_requests = [[req_name] + a for a in sub_requests]
|
||||
sub_requests = req.func.get_request_types_recursively() # recurse
|
||||
sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf
|
||||
requests.extend(sub_requests)
|
||||
else:
|
||||
requests.append([req_name])
|
||||
else: # leaf node found
|
||||
requests.append(req_name)
|
||||
return requests
|
||||
|
||||
def show(self) -> None:
|
||||
"""Display all currently available requests and whether they are valid."""
|
||||
table = PrettyTable(["request"])
|
||||
table.align = "l"
|
||||
table.add_rows(self.get_request_types_recursively())
|
||||
print(table)
|
||||
|
||||
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Check if this request would be valid in the current state of the simulation without invoking it."""
|
||||
|
||||
request_key = request[0]
|
||||
request_options = request[1:]
|
||||
|
||||
if request_key not in self.request_types:
|
||||
return False
|
||||
|
||||
request_type = self.request_types[request_key]
|
||||
|
||||
# recurse if we are not at a leaf node
|
||||
if isinstance(request_type.func, RequestManager):
|
||||
return request_type.func.check_valid(request_options, context)
|
||||
|
||||
return request_type.validator(request_options, context)
|
||||
|
||||
|
||||
class SimComponent(BaseModel):
|
||||
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
|
||||
|
||||
@@ -52,6 +52,8 @@ class GroupMembershipValidator(RequestPermissionValidator):
|
||||
def __call__(self, request: List[str], context: Dict) -> bool:
|
||||
"""Permit the action if the request comes from an account which belongs to the right group."""
|
||||
# if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false
|
||||
if not context:
|
||||
return False
|
||||
requestor_groups: List[str] = context["request_source"]["groups"]
|
||||
for allowed_group in self.allowed_groups:
|
||||
if allowed_group.name in requestor_groups:
|
||||
|
||||
@@ -92,6 +92,7 @@ class Service(IOSoftware):
|
||||
_is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING)
|
||||
_is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED)
|
||||
_is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED)
|
||||
_is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -131,7 +132,12 @@ class Service(IOSoftware):
|
||||
),
|
||||
)
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
rm.add_request(
|
||||
"enable",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"fix",
|
||||
RequestType(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -243,25 +243,25 @@ agents:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
@@ -272,22 +272,22 @@ agents:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
@@ -518,11 +518,22 @@ agents:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
@@ -557,6 +568,7 @@ agents:
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
action_masking: true
|
||||
|
||||
|
||||
|
||||
@@ -634,6 +646,8 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
|
||||
- type: server
|
||||
hostname: backup_server
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
@@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
ray.init(local_mode=True)
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
|
||||
1
tests/e2e_integration_tests/action_masking/__init__.py
Normal file
1
tests/e2e_integration_tests/action_masking/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -0,0 +1,158 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
from ray import air, init, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
|
||||
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
|
||||
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule
|
||||
from sb3_contrib import MaskablePPO
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
def test_sb3_action_masking(monkeypatch):
|
||||
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action
|
||||
# mask function here to save the output of the action mask method and pass through the result back to the agent.
|
||||
old_action_mask_method = PrimaiteGame.action_mask
|
||||
mask_history = []
|
||||
|
||||
def cache_action_mask(obj, agent_name):
|
||||
mask = old_action_mask_method(obj, agent_name)
|
||||
mask_history.append(mask)
|
||||
return mask
|
||||
|
||||
# Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which
|
||||
# action map action integer that was, therefore we cache it by using monkeypatch
|
||||
action_num_history = []
|
||||
|
||||
def cache_step(env, action: int):
|
||||
action_num_history.append(action)
|
||||
return PrimaiteGymEnv.step(env, action)
|
||||
|
||||
monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask)
|
||||
env = PrimaiteGymEnv(CFG_PATH)
|
||||
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
|
||||
|
||||
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
|
||||
model.learn(512)
|
||||
|
||||
assert len(action_num_history) == len(mask_history) > 0
|
||||
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
|
||||
assert any([not all(x) for x in mask_history])
|
||||
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
|
||||
# the N-th entry was True, meaning that it was a valid action at that step.
|
||||
# This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this
|
||||
# happens for all steps i in the episode
|
||||
assert all(mask_history[i][a] for i, a in enumerate(action_num_history))
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def test_ray_single_agent_action_masking(monkeypatch):
|
||||
"""Check that a Ray agent uses the action mask and never chooses invalid actions."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
for agent in cfg["agents"]:
|
||||
if agent["ref"] == "defender":
|
||||
agent["agent_settings"]["flatten_obs"] = True
|
||||
|
||||
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step
|
||||
# function to save the action mask and the agent's chosen action to a local variable.
|
||||
old_step_method = PrimaiteRayEnv.step
|
||||
action_num_history = []
|
||||
mask_history = []
|
||||
|
||||
def cache_step(self, action: int):
|
||||
action_num_history.append(action)
|
||||
obs, *_ = old_step_method(self, action)
|
||||
action_mask = obs["action_mask"]
|
||||
mask_history.append(action_mask)
|
||||
return obs, *_
|
||||
|
||||
monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
|
||||
|
||||
# Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule
|
||||
config = (
|
||||
PPOConfig()
|
||||
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
|
||||
.environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask")
|
||||
.rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule))
|
||||
.env_runners(num_env_runners=0)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
algo = config.build()
|
||||
algo.train()
|
||||
|
||||
assert len(action_num_history) == len(mask_history) > 0
|
||||
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
|
||||
assert any([not all(x) for x in mask_history])
|
||||
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
|
||||
# the N-th action was valid.
|
||||
# The first step uses the action mask provided by the reset method, so we are only checking from the second step
|
||||
# onward, that's why we need to use mask_history[:-1] and action_num_history[1:]
|
||||
assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:]))
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def test_ray_multi_agent_action_masking(monkeypatch):
|
||||
"""Check that Ray agents never take invalid actions when using MARL."""
|
||||
with open(MARL_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
old_step_method = PrimaiteRayMARLEnv.step
|
||||
action_num_history = {"defender_1": [], "defender_2": []}
|
||||
mask_history = {"defender_1": [], "defender_2": []}
|
||||
|
||||
def cache_step(self, actions: Dict[str, int]):
|
||||
for agent_name, action in actions.items():
|
||||
action_num_history[agent_name].append(action)
|
||||
obs, *_ = old_step_method(self, actions)
|
||||
for (
|
||||
agent_name,
|
||||
o,
|
||||
) in obs.items():
|
||||
mask_history[agent_name].append(o["action_mask"])
|
||||
return obs, *_
|
||||
|
||||
monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
|
||||
|
||||
config = (
|
||||
PPOConfig()
|
||||
.multi_agent(
|
||||
policies={
|
||||
"defender_1",
|
||||
"defender_2",
|
||||
}, # These names are the same as the agents defined in the example config.
|
||||
policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
|
||||
)
|
||||
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask")
|
||||
.rl_module(
|
||||
rl_module_spec=MultiAgentRLModuleSpec(
|
||||
module_specs={
|
||||
"defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
|
||||
"defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
|
||||
}
|
||||
)
|
||||
)
|
||||
.env_runners(num_env_runners=0)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
algo = config.build()
|
||||
algo.train()
|
||||
|
||||
for agent_name in ["defender_1", "defender_2"]:
|
||||
act_hist = action_num_history[agent_name]
|
||||
mask_hist = mask_history[agent_name]
|
||||
assert len(act_hist) == len(mask_hist) > 0
|
||||
assert any([not all(x) for x in mask_hist])
|
||||
assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:]))
|
||||
monkeypatch.undo()
|
||||
@@ -16,8 +16,6 @@ def test_rllib_multi_agent_compatibility():
|
||||
with open(MULTI_AGENT_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
ray.init()
|
||||
|
||||
config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)
|
||||
@@ -39,4 +37,3 @@ def test_rllib_multi_agent_compatibility():
|
||||
),
|
||||
param_space=config,
|
||||
).fit()
|
||||
ray.shutdown()
|
||||
|
||||
@@ -20,9 +20,6 @@ def test_rllib_single_agent_compatibility():
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
@@ -41,4 +38,3 @@ def test_rllib_single_agent_compatibility():
|
||||
assert save_file.exists()
|
||||
|
||||
save_file.unlink() # clean up
|
||||
ray.shutdown()
|
||||
|
||||
@@ -65,25 +65,25 @@ class TestPrimaiteEnvironment:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
|
||||
assert set(env._agent_ids) == {"defender1", "defender2"}
|
||||
assert set(env._agent_ids) == {"defender_1", "defender_2"}
|
||||
|
||||
assert len(env.agents) == 2
|
||||
defender1 = env.agents["defender1"]
|
||||
defender2 = env.agents["defender2"]
|
||||
assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54
|
||||
assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38
|
||||
defender_1 = env.agents["defender_1"]
|
||||
defender_2 = env.agents["defender_2"]
|
||||
assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 78
|
||||
assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 78
|
||||
|
||||
# ensure we can run all valid actions without error
|
||||
for act_1 in range(num_actions_1):
|
||||
env.step({"defender1": act_1, "defender2": 0})
|
||||
env.step({"defender_1": act_1, "defender_2": 0})
|
||||
for act_2 in range(num_actions_2):
|
||||
env.step({"defender1": 0, "defender2": act_2})
|
||||
env.step({"defender_1": 0, "defender_2": act_2})
|
||||
|
||||
# ensure we get error when taking an invalid action
|
||||
with pytest.raises(KeyError):
|
||||
env.step({"defender1": num_actions_1, "defender2": 0})
|
||||
env.step({"defender_1": num_actions_1, "defender_2": 0})
|
||||
with pytest.raises(KeyError):
|
||||
env.step({"defender1": 0, "defender2": num_actions_2})
|
||||
env.step({"defender_1": 0, "defender_2": num_actions_2})
|
||||
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
"""Make sure we throw an error when the config is bad."""
|
||||
|
||||
@@ -99,7 +99,7 @@ class TestConfigureDatabaseAction:
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == old_ip
|
||||
assert db_client.server_password is "admin123"
|
||||
assert db_client.server_password == "admin123"
|
||||
|
||||
|
||||
class TestConfigureRansomwareScriptAction:
|
||||
|
||||
161
tests/integration_tests/game_layer/test_action_mask.py
Normal file
161
tests/integration_tests/game_layer/test_action_mask.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
|
||||
|
||||
def test_mask_contents_correct():
|
||||
env = PrimaiteGymEnv(CFG_PATH)
|
||||
game = env.game
|
||||
sim = game.simulation
|
||||
net = sim.network
|
||||
mask = game.action_mask("defender")
|
||||
agent = env.agent
|
||||
node_list = agent.action_manager.node_names
|
||||
action_map = agent.action_manager.action_map
|
||||
|
||||
# CHECK NIC ENABLE/DISABLE ACTIONS
|
||||
for action_num, action in action_map.items():
|
||||
mask = game.action_mask("defender")
|
||||
act_type, act_params = action
|
||||
|
||||
if act_type == "NODE_NIC_ENABLE":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
|
||||
assert nic_obj.enabled
|
||||
assert not mask[action_num]
|
||||
nic_obj.disable()
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
nic_obj.enable()
|
||||
|
||||
if act_type == "NODE_NIC_DISABLE":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
|
||||
assert nic_obj.enabled
|
||||
assert mask[action_num]
|
||||
nic_obj.disable()
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
nic_obj.enable()
|
||||
|
||||
if act_type == "ROUTER_ACL_ADDRULE":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "ROUTER_ACL_REMOVERULE":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "NODE_RESET":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "NODE_SHUTDOWN":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "NODE_OS_SCAN":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "NODE_STARTUP":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
assert node_obj.operating_state is NodeOperatingState.ON
|
||||
assert not mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.OFF
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "DONOTHING":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "NODE_SERVICE_DISABLE":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]:
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.DISABLED
|
||||
mask = game.action_mask("defender")
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type == "NODE_SERVICE_RESUME":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.PAUSED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type == "NODE_SERVICE_START":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.STOPPED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type == "NODE_SERVICE_ENABLE":
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
service_obj = node_obj.software_manager.software.get(service_name)
|
||||
assert service_obj.operating_state is ServiceOperatingState.RUNNING
|
||||
assert not mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.DISABLED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]:
|
||||
node_name = node_list[act_params["node_id"]]
|
||||
folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"])
|
||||
file_name = agent.action_manager.get_file_name_by_idx(
|
||||
act_params["node_id"], act_params["folder_id"], act_params["file_id"]
|
||||
)
|
||||
node_obj = net.get_node_by_hostname(node_name)
|
||||
file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True)
|
||||
assert not file_obj.deleted
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.DISABLED
|
||||
mask = game.action_mask("defender")
|
||||
assert mask[action_num]
|
||||
service_obj.operating_state = ServiceOperatingState.RUNNING
|
||||
Reference in New Issue
Block a user