From cc721056d89563d64aae94ae9c936480a7c6388a Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Mar 2024 19:32:07 +0000 Subject: [PATCH] #2350: configurable NMNE category thresholds --- .../_package_data/data_manipulation.yaml | 5 + .../agent/observations/nic_observations.py | 175 ++++++++++++++++++ .../agent/observations/node_observations.py | 3 +- .../game/agent/observations/observations.py | 102 ---------- src/primaite/game/game.py | 7 +- ...software_installation_and_configuration.py | 11 +- .../test_game_options_config.py | 25 +++ .../observations/test_observations.py | 42 ++++- .../network/test_capture_nmne.py | 2 +- 9 files changed, 261 insertions(+), 111 deletions(-) create mode 100644 src/primaite/game/agent/observations/nic_observations.py create mode 100644 tests/integration_tests/configuration_file_parsing/test_game_options_config.py diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index dffb40ea..47204878 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -30,6 +30,11 @@ game: - ICMP - TCP - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 agents: - ref: client_2_green_user diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py new file mode 100644 index 00000000..39298ffe --- /dev/null +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -0,0 +1,175 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import CAPTURE_NMNE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NicObservation(AbstractObservation): + """Observation of a Network Interface Card (NIC) in the network.""" + + low_nmne_threshold: int = 0 + """The minimum number of malicious network events to be considered low.""" + med_nmne_threshold: int = 5 + """The minimum number of malicious network events to be considered medium.""" + high_nmne_threshold: int = 10 + """The minimum number of malicious network events to be considered high.""" + + @property + def default_observation(self) -> Dict: + """The default NIC observation dict.""" + data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"nmne": {"inbound": 0, "outbound": 0}}) + + return data + + def __init__( + self, + where: Optional[Tuple[str]] = None, + low_nmne_threshold: Optional[int] = 0, + med_nmne_threshold: Optional[int] = 5, + high_nmne_threshold: Optional[int] = 10, + ) -> None: + """Initialise NIC observation. + + :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical + example may look like this: + ['network','nodes',,'NICs',] + If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. + :type where: Optional[Tuple[str]], optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold: + self._validate_nmne_categories( + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) + + def _validate_nmne_categories( + self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10 + ): + """ + Validates the nmne threshold config. + + If the configuration is valid, the thresholds will be set, otherwise, an exception is raised. + + :param: low_nmne_threshold: The minimum number of malicious network events to be considered low + :param: med_nmne_threshold: The minimum number of malicious network events to be considered medium + :param: high_nmne_threshold: The minimum number of malicious network events to be considered high + """ + if high_nmne_threshold <= med_nmne_threshold: + raise Exception( + f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater " + f"than medium nmne count ({med_nmne_threshold})" + ) + + if med_nmne_threshold <= low_nmne_threshold: + raise Exception( + f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater " + f"than low nmne count ({low_nmne_threshold})" + ) + + self.high_nmne_threshold = high_nmne_threshold + self.med_nmne_threshold = med_nmne_threshold + self.low_nmne_threshold = low_nmne_threshold + + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (default 1-5 events). + - 2: Moderate number of MNEs (default 6-10 events). + - 3: High number of MNEs (default more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. + """ + if nmne_count > self.high_nmne_threshold: + return 3 + elif nmne_count > self.med_nmne_threshold: + return 2 + elif nmne_count > self.low_nmne_threshold: + return 1 + return 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + nic_state = access_from_nested_dict(state, self.where) + + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + else: + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"nmne": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) + obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) + return obs_dict + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "nic_status": spaces.Discrete(3), + "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": + """Create NIC observation from a config. + + :param config: Dictionary containing the configuration for this NIC observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent + node. A typical location for a node ``where`` can be: ['network','nodes',] + :type parent_where: Optional[List[str]] + :return: Constructed NIC observation + :rtype: NicObservation + """ + low_nmne_threshold = None + med_nmne_threshold = None + high_nmne_threshold = None + + if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"): + threshold = game.options.thresholds["nmne"] + + low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None + med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None + high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None + + return cls( + where=parent_where + ["NICs", config["nic_num"]], + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 93c6765b..f211a6b5 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -4,7 +4,8 @@ from gymnasium import spaces from primaite import getLogger from primaite.game.agent.observations.file_system_observations import FolderObservation -from primaite.game.agent.observations.observations import AbstractObservation, NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.agent.observations.observations import AbstractObservation from primaite.game.agent.observations.software_observation import ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 10e69ea5..6236b00d 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -7,7 +7,6 @@ from gymnasium import spaces from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.nmne import CAPTURE_NMNE _LOGGER = getLogger(__name__) @@ -116,107 +115,6 @@ class LinkObservation(AbstractObservation): return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -class NicObservation(AbstractObservation): - """Observation of a Network Interface Card (NIC) in the network.""" - - @property - def default_observation(self) -> Dict: - """The default NIC observation dict.""" - data = {"nic_status": 0} - if CAPTURE_NMNE: - data.update({"nmne": {"inbound": 0, "outbound": 0}}) - - return data - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise NIC observation. - - :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical - example may look like this: - ['network','nodes',,'NICs',] - If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. - :type where: Optional[Tuple[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def _categorise_mne_count(self, nmne_count: int) -> int: - """ - Categorise the number of Malicious Network Events (NMNEs) into discrete bins. - - This helps in classifying the severity or volume of MNEs into manageable levels for the agent. - - Bins are defined as follows: - - 0: No MNEs detected (0 events). - - 1: Low number of MNEs (1-5 events). - - 2: Moderate number of MNEs (6-10 events). - - 3: High number of MNEs (more than 10 events). - - :param nmne_count: Number of MNEs detected. - :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. - """ - if nmne_count > 10: - return 3 - elif nmne_count > 5: - return 2 - elif nmne_count > 0: - return 1 - return 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - nic_state = access_from_nested_dict(state, self.where) - - if nic_state is NOT_PRESENT_IN_STATE: - return self.default_observation - else: - obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} - if CAPTURE_NMNE: - obs_dict.update({"nmne": {}}) - direction_dict = nic_state["nmne"].get("direction", {}) - inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) - inbound_count = inbound_keywords.get("*", 0) - outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) - outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) - return obs_dict - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "nic_status": spaces.Discrete(3), - "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": - """Create NIC observation from a config. - - :param config: Dictionary containing the configuration for this NIC observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent - node. A typical location for a node ``where`` can be: ['network','nodes',] - :type parent_where: Optional[List[str]] - :return: Constructed NIC observation - :rtype: NicObservation - """ - return cls(where=parent_where + ["NICs", config["nic_num"]]) - - class AclObservation(AbstractObservation): """Observation of an Access Control List (ACL) in the network.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 33f9186b..3edb8651 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,6 +1,6 @@ """PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from pydantic import BaseModel, ConfigDict @@ -67,8 +67,13 @@ class PrimaiteGameOptions(BaseModel): model_config = ConfigDict(extra="forbid") max_episode_length: int = 256 + """Maximum number of episodes for the PrimAITE game.""" ports: List[str] + """A whitelist of available ports in the simulation.""" protocols: List[str] + """A whitelist of available protocols in the simulation.""" + thresholds: Optional[Dict] = {} + """A dict containing the thresholds used for determining what is acceptable during observations.""" class PrimaiteGame: diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index f993af5f..a5fcb372 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -5,8 +5,9 @@ from typing import Union import yaml from primaite.config.load import data_manipulation_config_path -from primaite.game.agent.interface import ProxyAgent, RandomAgent +from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -43,15 +44,15 @@ def test_example_config(): # green agent 1 assert "client_2_green_user" in game.agents - assert isinstance(game.agents["client_2_green_user"], RandomAgent) + assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent) # green agent 2 assert "client_1_green_user" in game.agents - assert isinstance(game.agents["client_1_green_user"], RandomAgent) + assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent) # red agent - assert "client_1_data_manipulation_red_bot" in game.agents - assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) + assert "data_manipulation_attacker" in game.agents + assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent) # blue agent assert "defender" in game.agents diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py new file mode 100644 index 00000000..adbbf2b5 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -0,0 +1,25 @@ +from pathlib import Path +from typing import Union + +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_thresholds(): + """Test that the game options can be parsed correctly.""" + game = load_config(data_manipulation_config_path()) + + assert game.options.thresholds is not None diff --git a/tests/integration_tests/game_layer/observations/test_observations.py b/tests/integration_tests/game_layer/observations/test_observations.py index eccda238..97df7882 100644 --- a/tests/integration_tests/game_layer/observations/test_observations.py +++ b/tests/integration_tests/game_layer/observations/test_observations.py @@ -1,6 +1,6 @@ import pytest -from primaite.game.agent.observations.observations import NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.sim_container import Simulation @@ -33,3 +33,43 @@ def test_nic(simulation): nic.disable() observation_state = nic_obs.observe(simulation.describe_state()) assert observation_state.get("nic_status") == 2 # disabled + + +def test_nic_categories(simulation): + """Test the NIC observation nmne count categories.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + assert nic_obs.high_nmne_threshold == 10 # default + assert nic_obs.med_nmne_threshold == 5 # default + assert nic_obs.low_nmne_threshold == 0 # default + + nic_obs = NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + assert nic_obs.high_nmne_threshold == 9 + assert nic_obs.med_nmne_threshold == 6 + assert nic_obs.low_nmne_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=9, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=9, + high_nmne_threshold=9, + ) diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 4bbde32f..32d4ee8f 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,4 +1,4 @@ -from primaite.game.agent.observations.observations import NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.sim_container import Simulation