diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 1c73d026..fe959c9f 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, identifier="FILE"): file_system_requires_scan: Optional[bool] = None """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: + def __init__( + self, + where: WhereType, + include_num_access: bool, + file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a file observation instance. @@ -48,10 +54,22 @@ class FileObservation(AbstractObservation, identifier="FILE"): if self.include_num_access: self.default_observation["num_access"] = 0 - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("file_access") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ], + threshold_identifier="file_access", + ): + self.low_threshold = thresholds.get("file_access")["low"] + self.med_threshold = thresholds.get("file_access")["medium"] + self.high_threshold = thresholds.get("file_access")["high"] def _categorise_num_access(self, num_access: int) -> int: """ @@ -122,6 +140,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) @@ -149,6 +168,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): num_files: int, include_num_access: bool, file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, ) -> None: """ Initialise a folder observation instance. @@ -170,6 +190,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self.file_system_requires_scan: bool = file_system_requires_scan + if thresholds.get("file_access") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ], + threshold_identifier="file_access", + ): + self.low_threshold = thresholds.get("file_access")["low"] + self.med_threshold = thresholds.get("file_access")["medium"] + self.high_threshold = thresholds.get("file_access")["high"] + self.files: List[FileObservation] = files while len(self.files) < num_files: self.files.append( @@ -177,6 +214,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): where=None, include_num_access=include_num_access, file_system_requires_scan=self.file_system_requires_scan, + thresholds=thresholds, ) ) while len(self.files) > num_files: @@ -248,6 +286,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): for file_config in config.files: file_config.include_num_access = config.include_num_access file_config.file_system_requires_scan = config.file_system_requires_scan + file_config.thresholds = config.thresholds files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls( @@ -256,4 +295,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): num_files=config.num_files, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 4419ccc7..fa7ceae5 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -151,7 +151,13 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.nics: List[NICObservation] = network_interfaces while len(self.nics) < num_nics: - self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic)) + self.nics.append( + NICObservation( + where=None, + include_nmne=include_nmne, + monitored_traffic=monitored_traffic, + ) + ) while len(self.nics) > num_nics: truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" @@ -257,12 +263,16 @@ class HostObservation(AbstractObservation, identifier="HOST"): where = parent_where + [config.hostname] # Pass down shared/common config items + for app_config in config.applications: + app_config.thresholds = config.thresholds for folder_config in config.folders: folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files folder_config.file_system_requires_scan = config.file_system_requires_scan + folder_config.thresholds = config.thresholds for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + nic_config.thresholds = config.thresholds services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] @@ -273,7 +283,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): count = 1 while len(nics) < config.num_nics: nic_config = NICObservation.ConfigSchema( - nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic + nic_num=count, + include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) nics.append(NICObservation.from_config(config=nic_config, parent_where=where)) count += 1 diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..48fa11dc 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -24,7 +24,13 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): monitored_traffic: Optional[Dict] = None """A dict containing which traffic types are to be included in the observation.""" - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: + def __init__( + self, + where: WhereType, + include_nmne: bool, + monitored_traffic: Optional[Dict] = None, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a network interface observation instance. @@ -44,10 +50,22 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.nmne_inbound_last_step: int = 0 self.nmne_outbound_last_step: int = 0 - # TODO: allow these to be configured in yaml - self.high_nmne_threshold = 10 - self.med_nmne_threshold = 5 - self.low_nmne_threshold = 0 + if thresholds.get("nmne") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("nmne")["low"], + thresholds.get("nmne")["medium"], + thresholds.get("nmne")["high"], + ], + threshold_identifier="nmne", + ): + self.low_threshold = thresholds.get("nmne")["low"] + self.med_threshold = thresholds.get("nmne")["medium"] + self.high_threshold = thresholds.get("nmne")["high"] self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -86,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): :param nmne_count: Number of MNEs detected. :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. """ - if nmne_count > self.high_nmne_threshold: + if nmne_count > self.high_threshold: return 3 - elif nmne_count > self.med_nmne_threshold: + elif nmne_count > self.med_threshold: return 2 - elif nmne_count > self.low_nmne_threshold: + elif nmne_count > self.low_threshold: return 1 return 0 @@ -224,6 +242,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index e263cadb..91bf402e 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -195,6 +195,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.file_system_requires_scan = config.file_system_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users + if host_config.thresholds is None: + host_config.thresholds = config.thresholds for router_config in config.routers: if router_config.num_ports is None: @@ -211,6 +213,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.num_rules = config.num_rules if router_config.include_users is None: router_config.include_users = config.include_users + if router_config.thresholds is None: + router_config.thresholds = config.thresholds for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -225,6 +229,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.num_rules = config.num_rules if firewall_config.include_users is None: firewall_config.include_users = config.include_users + if firewall_config.thresholds is None: + firewall_config.thresholds = config.thresholds hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 9b20fdcb..cc32918c 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -113,7 +113,9 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options)) + obs_instance = obs_class.from_config( + config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds) + ) instances[component.label] = obs_instance return cls(components=instances) @@ -176,7 +178,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Optional[Dict]) -> "ObservationManager": + def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager": """ Create observation space from a config. @@ -187,11 +189,15 @@ class ObservationManager: AbstractObservation options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict + :param thresholds: Dictionary containing the observation thresholds. + :type thresholds: Optional[Dict] """ if config is None: return cls(NullObservation()) obs_type = config["type"] obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"])) + observation = obs_class.from_config( + config=obs_class.ConfigSchema(**config["options"], thresholds=thresholds), + ) obs_manager = cls(observation) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index a9663c56..0b209f52 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union from gymnasium import spaces from gymnasium.core import ObsType @@ -19,6 +19,9 @@ class AbstractObservation(ABC): class ConfigSchema(ABC, BaseModel): """Config schema for observations.""" + thresholds: Optional[Dict] = None + """A dict containing the observation thresholds.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -67,3 +70,34 @@ class AbstractObservation(ABC): def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() + + def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool: + """ + Method that checks if the thresholds are non overlapping and in the correct (ascending) order. + + Pass in the thresholds from low to high e.g. + thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold] + + Throws an error if the threshold is not valid + + :param: thresholds: List of thresholds in ascending order. + :type: List[int] + :param: threshold_identifier: The name of the threshold option. + :type: Optional[str] + + :returns: bool + """ + if thresholds is None or len(thresholds) < 2: + raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}") + for idx in range(1, len(thresholds)): + if not isinstance(thresholds[idx], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + if not isinstance(thresholds[idx - 1], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + + if thresholds[idx] <= thresholds[idx - 1]: + raise Exception( + f"{threshold_identifier} threshold ({thresholds[idx]}) " + f"is greater than or equal to ({thresholds[idx - 1]}.)" + ) + return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 15cd2447..10adb5c5 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict +from typing import Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): application_name: str """Name of the application, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + def __init__(self, where: WhereType, thresholds: Optional[Dict] = {}) -> None: """ Initialise an application observation instance. @@ -94,16 +94,28 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): self.where = where self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("app_executions") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("app_executions")["low"], + thresholds.get("app_executions")["medium"], + thresholds.get("app_executions")["high"], + ], + threshold_identifier="app_executions", + ): + self.low_threshold = thresholds.get("app_executions")["low"] + self.med_threshold = thresholds.get("app_executions")["medium"] + self.high_threshold = thresholds.get("app_executions")["high"] def _categorise_num_executions(self, num_executions: int) -> int: """ - Represent number of file accesses as a categorical variable. + Represent number of application executions as a categorical variable. - :param num_access: Number of file accesses. + :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ if num_executions > self.high_threshold: @@ -161,4 +173,4 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["applications", config.application_name]) + return cls(where=parent_where + ["applications", config.application_name], thresholds=config.thresholds) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 123b6ddd..441ea632 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -531,7 +531,7 @@ class PrimaiteGame: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(observation_space_cfg) + obs_space = ObservationManager.from_config(config=observation_space_cfg, thresholds=game.options.thresholds) # CREATE ACTION SPACE action_space = ActionManager.from_config(game, action_space_cfg) diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index fed0f52d..03cf2207 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -25,7 +25,19 @@ game: - ICMP - TCP - UDP - + thresholds: + nmne: + high: 100 + medium: 25 + low: 5 + file_access: + high: 10 + medium: 5 + low: 2 + app_executions: + high: 5 + medium: 3 + low: 2 agents: - ref: client_2_green_user team: GREEN @@ -79,10 +91,16 @@ agents: options: hosts: - hostname: client_1 + applications: + - application_name: WebBrowser + folders: + - folder_name: root + files: + - file_name: "test.txt" - hostname: client_2 - hostname: client_3 num_services: 1 - num_applications: 0 + num_applications: 1 num_folders: 1 num_files: 1 num_nics: 2 @@ -219,6 +237,9 @@ simulation: options: ntp_server_ip: 192.168.1.10 - type: NTPServer + file_system: + - root: + - "test.txt" - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 32d88c92..2cb5520e 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from tests import TEST_ASSETS_ROOT -BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" +BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" def load_config(config_path: Union[str, Path]) -> PrimaiteGame: @@ -24,3 +24,42 @@ def test_thresholds(): game = load_config(data_manipulation_config_path()) assert game.options.thresholds is not None + + +def test_nmne_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["nmne"] is not None + + # get NIC observation + nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] + assert nic_obs.low_threshold == 5 + assert nic_obs.med_threshold == 25 + assert nic_obs.high_threshold == 100 + + +def test_file_access_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["file_access"] is not None + + # get file observation + file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] + assert file_obs.low_threshold == 2 + assert file_obs.med_threshold == 5 + assert file_obs.high_threshold == 10 + + +def test_app_executions_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["app_executions"] is not None + + # get application observation + app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] + assert app_obs.low_threshold == 2 + assert app_obs.med_threshold == 3 + assert app_obs.high_threshold == 5 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index e2ab2990..cbd9f8c0 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -44,6 +44,38 @@ def test_file_observation(simulation): assert observation_state.get("health_status") == 3 # corrupted +def test_config_file_access_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + file_obs = FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert file_obs.high_threshold == 9 + assert file_obs.med_threshold == 6 + assert file_obs.low_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}}, + ) + + def test_folder_observation(simulation): """Test the folder observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..cafdec45 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -110,33 +110,28 @@ def test_nic_categories(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) - assert nic_obs.high_nmne_threshold == 10 # default - assert nic_obs.med_nmne_threshold == 5 # default - assert nic_obs.low_nmne_threshold == 0 # default + assert nic_obs.high_threshold == 10 # default + assert nic_obs.med_threshold == 5 # default + assert nic_obs.low_threshold == 0 # default -@pytest.mark.skip(reason="Feature not implemented yet") def test_config_nic_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=6, - high_nmne_threshold=9, + thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}}, include_nmne=True, ) - assert nic_obs.high_nmne_threshold == 9 - assert nic_obs.med_nmne_threshold == 6 - assert nic_obs.low_nmne_threshold == 3 + assert nic_obs.high_threshold == 9 + assert nic_obs.med_threshold == 6 + assert nic_obs.low_threshold == 3 with pytest.raises(Exception): # should throw an error NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=9, - med_nmne_threshold=6, - high_nmne_threshold=9, + thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -144,9 +139,7 @@ def test_config_nic_categories(simulation): # should throw an error NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=9, - high_nmne_threshold=9, + thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}}, include_nmne=True, ) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 998aa755..25081585 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -69,3 +69,30 @@ def test_application_observation(simulation): assert observation_state.get("health_status") == 1 assert observation_state.get("operating_status") == 1 # running assert observation_state.get("num_executions") == 1 + + +def test_application_executions_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + app_obs = ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert app_obs.high_threshold == 9 + assert app_obs.med_threshold == 6 + assert app_obs.low_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}}, + )