diff --git a/docs/index.rst b/docs/index.rst index 5749ad56..431dea28 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -123,6 +123,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/environment source/customising_scenarios source/varying_config_files + source/action_masking .. toctree:: :caption: Notebooks: diff --git a/docs/source/action_masking.rst b/docs/source/action_masking.rst new file mode 100644 index 00000000..3e5b967b --- /dev/null +++ b/docs/source/action_masking.rst @@ -0,0 +1,80 @@ +.. only:: comment + + © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +Action Masking +************** +The PrimAITE simulation is able to provide action masks in the environment output. These action masks let the agents know +about which actions are invalid based on the current environment state. For instance, it's not possible to install +software on a node that is turned off. Therefore, if an agent has a NODE_SOFTWARE_INSTALL in it's action map for that node, +the action mask will show `0` in the corresponding entry. + +Configuration +============= +Action masking is supported for agents that use the `ProxyAgent` class (the class used for connecting to RL algorithms). +In order to use action masking, set the agent_settings.action_masking parameter to True in the config file. + +Masking Logic +============= +The following logic is applied: + +* **DONOTHING** : Always possible +* **NODE_HOST_SERVICE_SCAN** : Node is on. Service is running. +* **NODE_HOST_SERVICE_STOP** : Node is on. Service is running. +* **NODE_HOST_SERVICE_START** : Node is on. Service is stopped. +* **NODE_HOST_SERVICE_PAUSE** : Node is on. Service is running. +* **NODE_HOST_SERVICE_RESUME** : Node is on. Service is paused. +* **NODE_HOST_SERVICE_RESTART** : Node is on. Service is running. +* **NODE_HOST_SERVICE_DISABLE** : Node is on. +* **NODE_HOST_SERVICE_ENABLE** : Node is on. Service is disabled. +* **NODE_HOST_SERVICE_FIX** : Node is on. Service is running. +* **NODE_HOST_APPLICATION_EXECUTE** : Node is on. +* **NODE_HOST_APPLICATION_SCAN** : Node is on. Application is running. +* **NODE_HOST_APPLICATION_CLOSE** : Node is on. Application is running. +* **NODE_HOST_APPLICATION_FIX** : Node is on. Application is running. +* **NODE_HOST_APPLICATION_INSTALL** : Node is on. +* **NODE_HOST_APPLICATION_REMOVE** : Node is on. +* **NODE_HOST_FILE_SCAN** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_CREATE** : Node is on. +* **NODE_HOST_FILE_CHECKHASH** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_DELETE** : Node is on. File exists. +* **NODE_HOST_FILE_REPAIR** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_RESTORE** : Node is on. File exists. File is deleted. +* **NODE_HOST_FILE_CORRUPT** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_ACCESS** : Node is on. File exists. File not deleted. +* **NODE_HOST_FOLDER_CREATE** : Node is on. +* **NODE_HOST_FOLDER_SCAN** : Node is on. Folder exists. Folder not deleted. +* **NODE_HOST_FOLDER_CHECKHASH** : Node is on. Folder exists. Folder not deleted. +* **NODE_HOST_FOLDER_REPAIR** : Node is on. Folder exists. Folder not deleted. +* **NODE_HOST_FOLDER_RESTORE** : Node is on. Folder exists. Folder is deleted. +* **NODE_HOST_OS_SCAN** : Node is on. +* **NODE_HOST_NIC_ENABLE** : NIC is disabled. Node is on. +* **NODE_HOST_NIC_DISABLE** : NIC is enabled. Node is on. +* **NODE_HOST_SHUTDOWN** : Node is on. +* **NODE_HOST_STARTUP** : Node is off. +* **NODE_HOST_RESET** : Node is on. +* **NODE_HOST_NMAP_PING_SCAN** : Node is on. +* **NODE_HOST_NMAP_PORT_SCAN** : Node is on. +* **NODE_HOST_NMAP_NETWORK_SERVICE_RECON** : Node is on. +* **NODE_ROUTER_PORT_ENABLE** : Router is on. +* **NODE_ROUTER_PORT_DISABLE** : Router is on. +* **NODE_ROUTER_ACL_ADDRULE** : Router is on. +* **NODE_ROUTER_ACL_REMOVERULE** : Router is on. +* **NODE_FIREWALL_PORT_ENABLE** : Firewall is on. +* **NODE_FIREWALL_PORT_DISABLE** : Firewall is on. +* **NODE_FIREWALL_ACL_ADDRULE** : Firewall is on. +* **NODE_FIREWALL_ACL_REMOVERULE** : Firewall is on. + + +Mechanism +========= +The environment iterates over the RL agent's ``action_map`` and generates the corresponding simulator request string. +It uses the ``RequestManager.check_valid()`` method to invoke the relevant ``RequestPermissionValidator`` without +actually running the request on the simulation. + +Current Limitations +=================== +Currently, action masking only considers whether the action as a whole is possible, it doesn't verify that the exact +parameter combination passed to the action make sense in the current context. For instance, if ACL rule 3 on router_1 is +already populated, the action for adding another rule at position 3 will be available regardless, as long as that router +is turned on. This will never block valid actions. It will just occasionally allow invalid actions. diff --git a/pyproject.toml b/pyproject.toml index a0c2e3eb..9e919604 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ rl = [ "ray[rllib] >= 2.20.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", + "sb3-contrib==2.1.0", ] dev = [ "build==0.10.0", diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 6d4ec9b4..97442903 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -741,6 +741,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index 2e8221a0..ba666781 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -733,6 +733,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true - ref: defender_2 team: BLUE @@ -1316,6 +1317,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index 3483fc87..144e0733 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path: _LOGGER.error(msg) raise FileNotFoundError(msg) return path + + +def data_manipulation_marl_config_path() -> Path: + """ + Get the path to the MARL example config. + + :return: Path to yaml config file for the MARL scenario. + :rtype: Path + """ + path = _EXAMPLE_CFG / "data_manipulation_marl.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index b3b7189c..9a5fedc9 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -49,7 +49,7 @@ class AbstractAction(ABC): objects.""" @abstractmethod - def form_request(self) -> List[str]: + def form_request(self) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return [] @@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction): # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter # with one option. This just aids the Action Manager to enumerate all possibilities. - def form_request(self, **kwargs) -> List[str]: + def form_request(self, **kwargs) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["do_nothing"] @@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, service_id: int) -> List[str]: + def form_request(self, node_id: int, service_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) service_name = self.manager.get_service_name_by_idx(node_id, service_id) @@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, application_id: int) -> List[str]: + def form_request(self, node_id: int, application_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) application_name = self.manager.get_application_name_by_idx(node_id, application_id) @@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, folder_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "create" - def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]: + def form_request( + self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False + ) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None or file_name is None: @@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "create" - def form_request(self, node_id: int, folder_name: str) -> List[str]: + def form_request(self, node_id: int, folder_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None: @@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb: str = "delete" - def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "access" - def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]: + def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None or file_name is None: @@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int) -> List[str]: + def form_request(self, node_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) return ["network", "node", node_name, self.verb] @@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"position": max_acl_rules} - def form_request(self, target_router: str, position: int) -> List[str]: + def form_request(self, target_router: str, position: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["network", "node", target_router, "acl", "remove_rule", position] @@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, nic_id: int) -> List[str]: + def form_request(self, node_id: int, nic_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_idx=node_id) nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id) @@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - def form_request(self, target_nodename: str, port_id: int) -> List[str]: + def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if target_nodename is None or port_id is None: return ["do_nothing"] @@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - def form_request(self, target_nodename: str, port_id: int) -> List[str]: + def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if target_nodename is None or port_id is None: return ["do_nothing"] @@ -1315,7 +1317,7 @@ class ActionManager: act_identifier, act_options = self.action_map[action] return act_identifier, act_options - def form_request(self, action_identifier: str, action_options: Dict) -> List[str]: + def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" act_obj = self.actions[action_identifier] return act_obj.form_request(**action_options) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index c53b1956..f57dc191 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -70,6 +70,8 @@ class AgentSettings(BaseModel): "Configuration for when an agent begins performing it's actions" flatten_obs: bool = True "Whether to flatten the observation space before passing it to the agent. True by default." + action_masking: bool = False + "Whether to return action masks at each step." @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": @@ -207,6 +209,7 @@ class ProxyAgent(AbstractAgent): ) self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False + self.action_masking: bool = agent_settings.action_masking if agent_settings else False def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1e6aeae0..dcb4aa90 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional +import numpy as np from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger @@ -202,6 +203,23 @@ class PrimaiteGame: return True return False + def action_mask(self, agent_name: str) -> np.ndarray: + """ + Return the action mask for the agent. + + This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be + performed during this step. + + :return: Action mask + :rtype: List[bool] + """ + agent = self.agents[agent_name] + mask = [True] * len(agent.action_manager.action_map) + for i, action in agent.action_manager.action_map.items(): + request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) + mask[i] = self.simulation._request_manager.check_valid(request, {}) + return np.asarray(mask, dtype=np.int8) + def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb new file mode 100644 index 00000000..0e067b26 --- /dev/null +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Action Masking\n", + "\n", + "PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from prettytable import PrettyTable\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env = PrimaiteGymEnv(data_manipulation_config_path())\n", + "env.action_masking = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible. Demonstrated here:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n", + "mask = env.action_masks()\n", + "actions = env.agent.action_manager.action_map\n", + "max_str_len = 70\n", + "for act,mask in zip(actions.items(), mask):\n", + " act_num, act_data = act\n", + " act_type, act_params = act_data\n", + " act_params = s if len(s:=str(act_params)) np.ndarray: + """ + Return the action mask for the agent. + + This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be + performed during this step. + + :return: Action mask + :rtype: List[bool] + """ + if not self.agent.action_masking: + return np.asarray([True] * len(self.agent.action_manager.action_map)) + else: + return self.game.action_mask(self._agent_name) + @property def agent(self) -> ProxyAgent: """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index fc5d73d8..1adc324c 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -3,6 +3,7 @@ import json from typing import Dict, SupportsFloat, Tuple import gymnasium +from gymnasium import spaces from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -38,15 +39,19 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.terminateds = set() self.truncateds = set() - self.observation_space = gymnasium.spaces.Dict( - { - name: gymnasium.spaces.flatten_space(agent.observation_manager.space) - for name, agent in self.agents.items() - } - ) - self.action_space = gymnasium.spaces.Dict( - {name: agent.action_manager.space for name, agent in self.agents.items()} + self.observation_space = spaces.Dict( + {name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()} ) + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] + if agent.action_masking: + self.observation_space[agent_name] = spaces.Dict( + { + "action_mask": spaces.MultiBinary(agent.action_manager.space.n), + "observations": self.observation_space[agent_name], + } + ) + self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()}) self._obs_space_in_preferred_format = True self._action_space_in_preferred_format = True super().__init__() @@ -131,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" - obs = {} + all_obs = {} for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) - return obs + obs = gymnasium.spaces.flatten(unflat_space, unflat_obs) + if agent.action_masking: + all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs} + else: + all_obs[agent_name] = obs + return all_obs def close(self): """Close the simulation.""" @@ -158,15 +167,30 @@ class PrimaiteRayEnv(gymnasium.Env): self.env = PrimaiteGymEnv(env_config=env_config) # self.env.episode_counter -= 1 self.action_space = self.env.action_space - self.observation_space = self.env.observation_space + if self.env.agent.action_masking: + self.observation_space = spaces.Dict( + {"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space} + ) + else: + self.observation_space = self.env.observation_space def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + if self.env.agent.action_masking: + obs, *_ = self.env.reset(seed=seed) + new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + return new_obs, *_ return self.env.reset(seed=seed) def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" - return self.env.step(action) + # if action masking is enabled, intercept the step method and add action mask to observation + if self.env.agent.action_masking: + obs, *_ = self.env.step(action) + new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs} + return new_obs, *_ + else: + return self.env.step(action) def close(self): """Close the simulation.""" diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 70485af5..a5e39cc8 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,9 +3,10 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Callable, Dict, List, Literal, Optional, Union +from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 +from prettytable import PrettyTable from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger @@ -164,18 +165,51 @@ class RequestManager(BaseModel): self.request_types.pop(name) - def get_request_types_recursively(self) -> List[List[str]]: - """Recursively generate request tree for this component.""" + def get_request_types_recursively(self) -> List[RequestFormat]: + """ + Recursively generate request tree for this component. + + :param parent_valid: Whether this sub-request's parent request was valid. This value should not be specified by + users, it is used by the recursive call. + :type parent_valid: bool + :returns: A list of tuples where the first tuple element is the request string and the second is whether that + request is currently possible to execute. + :rtype: List[Tuple[RequestFormat, bool]] + """ requests = [] for req_name, req in self.request_types.items(): if isinstance(req.func, RequestManager): - sub_requests = req.func.get_request_types_recursively() - sub_requests = [[req_name] + a for a in sub_requests] + sub_requests = req.func.get_request_types_recursively() # recurse + sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf requests.extend(sub_requests) - else: - requests.append([req_name]) + else: # leaf node found + requests.append(req_name) return requests + def show(self) -> None: + """Display all currently available requests and whether they are valid.""" + table = PrettyTable(["request"]) + table.align = "l" + table.add_rows(self.get_request_types_recursively()) + print(table) + + def check_valid(self, request: RequestFormat, context: Dict) -> bool: + """Check if this request would be valid in the current state of the simulation without invoking it.""" + + request_key = request[0] + request_options = request[1:] + + if request_key not in self.request_types: + return False + + request_type = self.request_types[request_key] + + # recurse if we are not at a leaf node + if isinstance(request_type.func, RequestManager): + return request_type.func.check_valid(request_options, context) + + return request_type.validator(request_options, context) + class SimComponent(BaseModel): """Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator.""" diff --git a/src/primaite/simulator/domain/controller.py b/src/primaite/simulator/domain/controller.py index 37e60aaa..a264ba24 100644 --- a/src/primaite/simulator/domain/controller.py +++ b/src/primaite/simulator/domain/controller.py @@ -52,6 +52,8 @@ class GroupMembershipValidator(RequestPermissionValidator): def __call__(self, request: List[str], context: Dict) -> bool: """Permit the action if the request comes from an account which belongs to the right group.""" # if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false + if not context: + return False requestor_groups: List[str] = context["request_source"]["groups"] for allowed_group in self.allowed_groups: if allowed_group.name in requestor_groups: diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 8167a8a9..5adea6e7 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -92,6 +92,7 @@ class Service(IOSoftware): _is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING) _is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED) _is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED) + _is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED) rm = super()._init_request_manager() rm.add_request( @@ -131,7 +132,12 @@ class Service(IOSoftware): ), ) rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) - rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request( + "enable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled + ), + ) rm.add_request( "fix", RequestType( diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 971f36f8..a2d64605 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1,3 +1,10 @@ +io_settings: + save_agent_actions: false + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + + game: max_episode_length: 128 ports: @@ -13,31 +20,105 @@ game: agents: - ref: client_2_green_user team: GREEN - type: PeriodicAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: null action_space: action_list: - type: DONOTHING - type: NODE_APPLICATION_EXECUTE - options: nodes: - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 - ref: data_manipulation_attacker team: RED @@ -57,6 +138,9 @@ agents: - node_name: client_1 applications: - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -71,7 +155,7 @@ agents: frequency: 20 variance: 5 - - ref: defender1 + - ref: defender_1 team: BLUE type: ProxyAgent @@ -194,318 +278,425 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_FIX" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_id: 0 + 22: + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 23: + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 24: # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 24: + action: NODE_STARTUP options: - target_router: router_1 - position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 25: # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 25: + action: NODE_RESET options: - target_router: router_1 - position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 26: - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + node_id: 2 27: - action: "ROUTER_ACL_ADDRULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + node_id: 2 28: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 0 + node_id: 2 29: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 1 + node_id: 2 30: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 2 + node_id: 3 31: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 3 + node_id: 3 32: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 4 + node_id: 3 33: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 5 + node_id: 3 34: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 6 + node_id: 4 35: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 7 + node_id: 4 36: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 8 + node_id: 4 37: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 9 + node_id: 4 38: - action: "HOST_NIC_DISABLE" + action: "NODE_OS_SCAN" options: - node_id: 0 - nic_id: 0 - 39: - action: "HOST_NIC_ENABLE" + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" options: - node_id: 0 - nic_id: 0 - 40: - action: "HOST_NIC_DISABLE" + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP options: - node_id: 1 - nic_id: 0 - 41: - action: "HOST_NIC_ENABLE" + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET options: - node_id: 1 - nic_id: 0 + node_id: 5 42: - action: "HOST_NIC_DISABLE" + action: "NODE_OS_SCAN" options: - node_id: 2 - nic_id: 0 + node_id: 6 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 action: "HOST_NIC_ENABLE" options: node_id: 2 nic_id: 0 - 44: + 68: # old action num: 44 action: "HOST_NIC_DISABLE" options: node_id: 3 nic_id: 0 - 45: + 69: # old action num: 45 action: "HOST_NIC_ENABLE" options: node_id: 3 nic_id: 0 - 46: + 70: # old action num: 46 action: "HOST_NIC_DISABLE" options: node_id: 4 nic_id: 0 - 47: + 71: # old action num: 47 action: "HOST_NIC_ENABLE" options: node_id: 4 nic_id: 0 - 48: + 72: # old action num: 48 action: "HOST_NIC_DISABLE" options: node_id: 4 nic_id: 1 - 49: + 73: # old action num: 49 action: "HOST_NIC_ENABLE" options: node_id: 4 nic_id: 1 - 50: + 74: # old action num: 50 action: "HOST_NIC_DISABLE" options: node_id: 5 nic_id: 0 - 51: + 75: # old action num: 51 action: "HOST_NIC_ENABLE" options: node_id: 5 nic_id: 0 - 52: + 76: # old action num: 52 action: "HOST_NIC_DISABLE" options: node_id: 6 nic_id: 0 - 53: + 77: # old action num: 53 action: "HOST_NIC_ENABLE" options: node_id: 6 nic_id: 0 - options: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -521,27 +712,30 @@ agents: - 192.168.10.22 - 192.168.10.110 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: web_server - service_name: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... + flatten_obs: true + action_masking: true - - ref: defender2 + - ref: defender_2 team: BLUE type: ProxyAgent @@ -640,7 +834,11 @@ agents: - type: NODE_STARTUP - type: NODE_RESET - type: ROUTER_ACL_ADDRULE + options: + target_router: router_1 - type: ROUTER_ACL_REMOVERULE + options: + target_router: router_1 - type: HOST_NIC_ENABLE - type: HOST_NIC_DISABLE @@ -664,99 +862,196 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_FIX" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -769,7 +1064,7 @@ agents: protocol_id: 1 source_wildcard_id: 0 dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -782,7 +1077,7 @@ agents: protocol_id: 1 source_wildcard_id: 0 dest_wildcard_id: 0 - 24: # block tcp traffic from client 1 to web app + 48: # old action num: 24 # block tcp traffic from client 1 to web app action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -795,7 +1090,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 25: # block tcp traffic from client 2 to web app + 49: # old action num: 25 # block tcp traffic from client 2 to web app action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -808,7 +1103,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 26: + 50: # old action num: 26 action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -821,7 +1116,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 27: + 51: # old action num: 27 action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -834,67 +1129,159 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 28: + 52: # old action num: 28 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 0 - 29: + 53: # old action num: 29 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 1 - 30: + 54: # old action num: 30 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 2 - 31: + 55: # old action num: 31 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 3 - 32: + 56: # old action num: 32 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 4 - 33: + 57: # old action num: 33 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 5 - 34: + 58: # old action num: 34 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 6 - 35: + 59: # old action num: 35 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 7 - 36: + 60: # old action num: 36 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 8 - 37: + 61: # old action num: 37 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + options: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -913,50 +1300,63 @@ agents: reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: web_server - service_name: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... - + flatten_obs: true + action_masking: true simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - - type: router - hostname: router_1 + - hostname: router_1 + type: router num_ports: 5 ports: 1: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: - 0: + 18: action: PERMIT src_port: POSTGRES_SERVER dst_port: POSTGRES_SERVER - 1: + 19: action: PERMIT src_port: DNS dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -965,16 +1365,16 @@ simulation: action: PERMIT protocol: ICMP - - type: switch - hostname: switch_1 + - hostname: switch_1 + type: switch num_ports: 8 - - type: switch - hostname: switch_2 + - hostname: switch_2 + type: switch num_ports: 8 - - type: server - hostname: domain_controller + - hostname: domain_controller + type: server ip_address: 192.168.1.10 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -984,8 +1384,8 @@ simulation: domain_mapping: arcd.com: 192.168.1.12 # web server - - type: server - hostname: web_server + - hostname: web_server + type: server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -997,17 +1397,21 @@ simulation: options: db_server_ip: 192.168.1.14 - - type: server - hostname: database_server + + - hostname: database_server + type: server ip_address: 192.168.1.14 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient - - type: server - hostname: backup_server + - hostname: backup_server + type: server ip_address: 192.168.1.16 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1015,8 +1419,8 @@ simulation: services: - type: FTPServer - - type: server - hostname: security_suite + - hostname: security_suite + type: server ip_address: 192.168.1.110 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1026,8 +1430,8 @@ simulation: ip_address: 192.168.10.110 subnet_mask: 255.255.255.0 - - type: computer - hostname: client_1 + - hostname: client_1 + type: computer ip_address: 192.168.10.21 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1035,24 +1439,43 @@ simulation: applications: - type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - type: DNSClient - - type: computer - hostname: client_2 + - hostname: client_2 + type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - type: DNSClient + + links: - endpoint_a_hostname: router_1 endpoint_a_port: 1 diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 54143af0..eb8103e8 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -243,25 +243,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_FIX" @@ -272,22 +272,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -518,11 +518,22 @@ agents: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -557,6 +568,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true @@ -634,6 +646,8 @@ simulation: dns_server: 192.168.1.10 services: - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 - type: server hostname: backup_server diff --git a/tests/conftest.py b/tests/conftest.py index a60a26f1..b8b50182 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Tuple import pytest +import ray import yaml from primaite import getLogger, PRIMAITE_PATHS @@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT +ray.init(local_mode=True) ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 diff --git a/tests/e2e_integration_tests/action_masking/__init__.py b/tests/e2e_integration_tests/action_masking/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/e2e_integration_tests/action_masking/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py new file mode 100644 index 00000000..a299b913 --- /dev/null +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -0,0 +1,158 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import importlib +from typing import Dict + +import yaml +from ray import air, init, tune +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule +from sb3_contrib import MaskablePPO + +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from tests import TEST_ASSETS_ROOT + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" +MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" + + +def test_sb3_action_masking(monkeypatch): + # There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action + # mask function here to save the output of the action mask method and pass through the result back to the agent. + old_action_mask_method = PrimaiteGame.action_mask + mask_history = [] + + def cache_action_mask(obj, agent_name): + mask = old_action_mask_method(obj, agent_name) + mask_history.append(mask) + return mask + + # Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which + # action map action integer that was, therefore we cache it by using monkeypatch + action_num_history = [] + + def cache_step(env, action: int): + action_num_history.append(action) + return PrimaiteGymEnv.step(env, action) + + monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask) + env = PrimaiteGymEnv(CFG_PATH) + monkeypatch.setattr(env, "step", lambda action: cache_step(env, action)) + + model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32) + model.learn(512) + + assert len(action_num_history) == len(mask_history) > 0 + # Make sure the masks had at least some False entries, if it was all True then the mask was disabled + assert any([not all(x) for x in mask_history]) + # When the agent takes action N from its action map, we need to have a look at the action mask and make sure that + # the N-th entry was True, meaning that it was a valid action at that step. + # This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this + # happens for all steps i in the episode + assert all(mask_history[i][a] for i, a in enumerate(action_num_history)) + monkeypatch.undo() + + +def test_ray_single_agent_action_masking(monkeypatch): + """Check that a Ray agent uses the action mask and never chooses invalid actions.""" + with open(CFG_PATH, "r") as f: + cfg = yaml.safe_load(f) + for agent in cfg["agents"]: + if agent["ref"] == "defender": + agent["agent_settings"]["flatten_obs"] = True + + # There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step + # function to save the action mask and the agent's chosen action to a local variable. + old_step_method = PrimaiteRayEnv.step + action_num_history = [] + mask_history = [] + + def cache_step(self, action: int): + action_num_history.append(action) + obs, *_ = old_step_method(self, action) + action_mask = obs["action_mask"] + mask_history.append(action_mask) + return obs, *_ + + monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs)) + + # Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule + config = ( + PPOConfig() + .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) + .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask") + .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule)) + .env_runners(num_env_runners=0) + .training(train_batch_size=128) + ) + algo = config.build() + algo.train() + + assert len(action_num_history) == len(mask_history) > 0 + # Make sure the masks had at least some False entries, if it was all True then the mask was disabled + assert any([not all(x) for x in mask_history]) + # When the agent takes action N from its action map, we need to have a look at the action mask and make sure that + # the N-th action was valid. + # The first step uses the action mask provided by the reset method, so we are only checking from the second step + # onward, that's why we need to use mask_history[:-1] and action_num_history[1:] + assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:])) + monkeypatch.undo() + + +def test_ray_multi_agent_action_masking(monkeypatch): + """Check that Ray agents never take invalid actions when using MARL.""" + with open(MARL_PATH, "r") as f: + cfg = yaml.safe_load(f) + + old_step_method = PrimaiteRayMARLEnv.step + action_num_history = {"defender_1": [], "defender_2": []} + mask_history = {"defender_1": [], "defender_2": []} + + def cache_step(self, actions: Dict[str, int]): + for agent_name, action in actions.items(): + action_num_history[agent_name].append(action) + obs, *_ = old_step_method(self, actions) + for ( + agent_name, + o, + ) in obs.items(): + mask_history[agent_name].append(o["action_mask"]) + return obs, *_ + + monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs)) + + config = ( + PPOConfig() + .multi_agent( + policies={ + "defender_1", + "defender_2", + }, # These names are the same as the agents defined in the example config. + policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id, + ) + .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) + .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask") + .rl_module( + rl_module_spec=MultiAgentRLModuleSpec( + module_specs={ + "defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule), + "defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule), + } + ) + ) + .env_runners(num_env_runners=0) + .training(train_batch_size=128) + ) + algo = config.build() + algo.train() + + for agent_name in ["defender_1", "defender_2"]: + act_hist = action_num_history[agent_name] + mask_hist = mask_history[agent_name] + assert len(act_hist) == len(mask_hist) > 0 + assert any([not all(x) for x in mask_hist]) + assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:])) + monkeypatch.undo() diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 96ec799c..e015c33c 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -16,8 +16,6 @@ def test_rllib_multi_agent_compatibility(): with open(MULTI_AGENT_PATH, "r") as f: cfg = yaml.safe_load(f) - ray.init() - config = ( PPOConfig() .environment(env=PrimaiteRayMARLEnv, env_config=cfg) @@ -39,4 +37,3 @@ def test_rllib_multi_agent_compatibility(): ), param_space=config, ).fit() - ray.shutdown() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index d6cacfd2..a02a078c 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -20,9 +20,6 @@ def test_rllib_single_agent_compatibility(): game = PrimaiteGame.from_config(cfg) - ray.shutdown() - ray.init() - env_config = {"game": game} config = { "env": PrimaiteRayEnv, @@ -41,4 +38,3 @@ def test_rllib_single_agent_compatibility(): assert save_file.exists() save_file.unlink() # clean up - ray.shutdown() diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index c8238aba..dcd51193 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -65,25 +65,25 @@ class TestPrimaiteEnvironment: cfg = yaml.safe_load(f) env = PrimaiteRayMARLEnv(env_config=cfg) - assert set(env._agent_ids) == {"defender1", "defender2"} + assert set(env._agent_ids) == {"defender_1", "defender_2"} assert len(env.agents) == 2 - defender1 = env.agents["defender1"] - defender2 = env.agents["defender2"] - assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54 - assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38 + defender_1 = env.agents["defender_1"] + defender_2 = env.agents["defender_2"] + assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 78 + assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 78 # ensure we can run all valid actions without error for act_1 in range(num_actions_1): - env.step({"defender1": act_1, "defender2": 0}) + env.step({"defender_1": act_1, "defender_2": 0}) for act_2 in range(num_actions_2): - env.step({"defender1": 0, "defender2": act_2}) + env.step({"defender_1": 0, "defender_2": act_2}) # ensure we get error when taking an invalid action with pytest.raises(KeyError): - env.step({"defender1": num_actions_1, "defender2": 0}) + env.step({"defender_1": num_actions_1, "defender_2": 0}) with pytest.raises(KeyError): - env.step({"defender1": 0, "defender2": num_actions_2}) + env.step({"defender_1": 0, "defender_2": num_actions_2}) def test_error_thrown_on_bad_configuration(self): """Make sure we throw an error when the config is bad.""" diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index b7acc8a8..0c9ec6f0 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -99,7 +99,7 @@ class TestConfigureDatabaseAction: game.step() assert db_client.server_ip_address == old_ip - assert db_client.server_password is "admin123" + assert db_client.server_password == "admin123" class TestConfigureRansomwareScriptAction: diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py new file mode 100644 index 00000000..64464724 --- /dev/null +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -0,0 +1,161 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.system.services.service import ServiceOperatingState +from tests.conftest import TEST_ASSETS_ROOT + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" + + +def test_mask_contents_correct(): + env = PrimaiteGymEnv(CFG_PATH) + game = env.game + sim = game.simulation + net = sim.network + mask = game.action_mask("defender") + agent = env.agent + node_list = agent.action_manager.node_names + action_map = agent.action_manager.action_map + + # CHECK NIC ENABLE/DISABLE ACTIONS + for action_num, action in action_map.items(): + mask = game.action_mask("defender") + act_type, act_params = action + + if act_type == "NODE_NIC_ENABLE": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + assert nic_obj.enabled + assert not mask[action_num] + nic_obj.disable() + mask = game.action_mask("defender") + assert mask[action_num] + nic_obj.enable() + + if act_type == "NODE_NIC_DISABLE": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + assert nic_obj.enabled + assert mask[action_num] + nic_obj.disable() + mask = game.action_mask("defender") + assert not mask[action_num] + nic_obj.enable() + + if act_type == "ROUTER_ACL_ADDRULE": + assert mask[action_num] + + if act_type == "ROUTER_ACL_REMOVERULE": + assert mask[action_num] + + if act_type == "NODE_RESET": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_SHUTDOWN": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_OS_SCAN": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_STARTUP": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "DONOTHING": + assert mask[action_num] + + if act_type == "NODE_SERVICE_DISABLE": + assert mask[action_num] + + if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]: + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_RESUME": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.PAUSED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_START": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.STOPPED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_ENABLE": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]: + node_name = node_list[act_params["node_id"]] + folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"]) + file_name = agent.action_manager.get_file_name_by_idx( + act_params["node_id"], act_params["folder_id"], act_params["file_id"] + ) + node_obj = net.get_node_by_hostname(node_name) + file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True) + assert not file_obj.deleted + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING