diff --git a/CHANGELOG.md b/CHANGELOG.md index 056742e4..c748a969 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Log observation space data by episode and step. - Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted. - New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to be able to set the observation threshold for NMNE, file access and application executions ### Changed - ACL's are no longer applied to layer-2 traffic. diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index fe959c9f..b24b26a6 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -55,21 +55,35 @@ class FileObservation(AbstractObservation, identifier="FILE"): self.default_observation["num_access"] = 0 if thresholds.get("file_access") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 + self.low_file_access_threshold = 0 + self.med_file_access_threshold = 5 + self.high_file_access_threshold = 10 else: - if self._validate_thresholds( + self._set_file_access_threshold( thresholds=[ thresholds.get("file_access")["low"], thresholds.get("file_access")["medium"], thresholds.get("file_access")["high"], - ], - threshold_identifier="file_access", - ): - self.low_threshold = thresholds.get("file_access")["low"] - self.med_threshold = thresholds.get("file_access")["medium"] - self.high_threshold = thresholds.get("file_access")["high"] + ] + ) + + def _set_file_access_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the file access threshold. + + :param: thresholds: The file access threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="file_access", + ): + self.low_file_access_threshold = thresholds[0] + self.med_file_access_threshold = thresholds[1] + self.high_file_access_threshold = thresholds[2] def _categorise_num_access(self, num_access: int) -> int: """ @@ -78,11 +92,11 @@ class FileObservation(AbstractObservation, identifier="FILE"): :param num_access: Number of file accesses. :return: Bin number corresponding to the number of accesses. """ - if num_access > self.high_threshold: + if num_access > self.high_file_access_threshold: return 3 - elif num_access > self.med_threshold: + elif num_access > self.med_file_access_threshold: return 2 - elif num_access > self.low_threshold: + elif num_access > self.low_file_access_threshold: return 1 return 0 @@ -190,23 +204,6 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self.file_system_requires_scan: bool = file_system_requires_scan - if thresholds.get("file_access") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 - else: - if self._validate_thresholds( - thresholds=[ - thresholds.get("file_access")["low"], - thresholds.get("file_access")["medium"], - thresholds.get("file_access")["high"], - ], - threshold_identifier="file_access", - ): - self.low_threshold = thresholds.get("file_access")["low"] - self.med_threshold = thresholds.get("file_access")["medium"] - self.high_threshold = thresholds.get("file_access")["high"] - self.files: List[FileObservation] = files while len(self.files) < num_files: self.files.append( diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 30ee240d..0dabd9f4 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import ClassVar, Dict, Optional +from typing import ClassVar, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -55,21 +55,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.nmne_outbound_last_step: int = 0 if thresholds.get("nmne") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 + self.low_nmne_threshold = 0 + self.med_nmne_threshold = 5 + self.high_nmne_threshold = 10 else: - if self._validate_thresholds( + self._set_nmne_threshold( thresholds=[ thresholds.get("nmne")["low"], thresholds.get("nmne")["medium"], thresholds.get("nmne")["high"], - ], - threshold_identifier="nmne", - ): - self.low_threshold = thresholds.get("nmne")["low"] - self.med_threshold = thresholds.get("nmne")["medium"] - self.high_threshold = thresholds.get("nmne")["high"] + ] + ) self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -108,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): :param nmne_count: Number of MNEs detected. :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. """ - if nmne_count > self.high_threshold: + if nmne_count > self.high_nmne_threshold: return 3 - elif nmne_count > self.med_threshold: + elif nmne_count > self.med_nmne_threshold: return 2 - elif nmne_count > self.low_threshold: + elif nmne_count > self.low_nmne_threshold: return 1 return 0 @@ -126,6 +122,20 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): bandwidth_utilisation = traffic_value / nic_max_bandwidth return int(bandwidth_utilisation * 9) + 1 + def _set_nmne_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the NMNE threshold. + + :param: thresholds: The NMNE threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=thresholds, + threshold_identifier="nmne", + ): + self.low_nmne_threshold = thresholds[0] + self.med_nmne_threshold = thresholds[1] + self.high_nmne_threshold = thresholds[2] + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 0b209f52..7a31a26b 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -97,7 +97,7 @@ class AbstractObservation(ABC): if thresholds[idx] <= thresholds[idx - 1]: raise Exception( - f"{threshold_identifier} threshold ({thresholds[idx]}) " - f"is greater than or equal to ({thresholds[idx - 1]}.)" + f"{threshold_identifier} threshold ({thresholds[idx - 1]}) " + f"is greater than or equal to ({thresholds[idx]}.)" ) return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 10ffe3fc..0318c864 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -109,21 +109,35 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} if thresholds.get("app_executions") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 + self.low_app_execution_threshold = 0 + self.med_app_execution_threshold = 5 + self.high_app_execution_threshold = 10 else: - if self._validate_thresholds( + self._set_application_execution_thresholds( thresholds=[ thresholds.get("app_executions")["low"], thresholds.get("app_executions")["medium"], thresholds.get("app_executions")["high"], - ], - threshold_identifier="app_executions", - ): - self.low_threshold = thresholds.get("app_executions")["low"] - self.med_threshold = thresholds.get("app_executions")["medium"] - self.high_threshold = thresholds.get("app_executions")["high"] + ] + ) + + def _set_application_execution_thresholds(self, thresholds: List[int]): + """ + Method that validates and then sets the application execution threshold. + + :param: thresholds: The application execution threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="app_executions", + ): + self.low_app_execution_threshold = thresholds[0] + self.med_app_execution_threshold = thresholds[1] + self.high_app_execution_threshold = thresholds[2] def _categorise_num_executions(self, num_executions: int) -> int: """ @@ -132,11 +146,11 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ - if num_executions > self.high_threshold: + if num_executions > self.high_app_execution_threshold: return 3 - elif num_executions > self.med_threshold: + elif num_executions > self.med_app_execution_threshold: return 2 - elif num_executions > self.low_threshold: + elif num_executions > self.low_app_execution_threshold: return 1 return 0 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 2cb5520e..4098db7f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -34,9 +34,9 @@ def test_nmne_threshold(): # get NIC observation nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] - assert nic_obs.low_threshold == 5 - assert nic_obs.med_threshold == 25 - assert nic_obs.high_threshold == 100 + assert nic_obs.low_nmne_threshold == 5 + assert nic_obs.med_nmne_threshold == 25 + assert nic_obs.high_nmne_threshold == 100 def test_file_access_threshold(): @@ -47,9 +47,9 @@ def test_file_access_threshold(): # get file observation file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] - assert file_obs.low_threshold == 2 - assert file_obs.med_threshold == 5 - assert file_obs.high_threshold == 10 + assert file_obs.low_file_access_threshold == 2 + assert file_obs.med_file_access_threshold == 5 + assert file_obs.high_file_access_threshold == 10 def test_app_executions_threshold(): @@ -60,6 +60,6 @@ def test_app_executions_threshold(): # get application observation app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] - assert app_obs.low_threshold == 2 - assert app_obs.med_threshold == 3 - assert app_obs.high_threshold == 5 + assert app_obs.low_app_execution_threshold == 2 + assert app_obs.med_app_execution_threshold == 3 + assert app_obs.high_app_execution_threshold == 5 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index cbd9f8c0..6356c297 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -53,9 +53,9 @@ def test_config_file_access_categories(simulation): thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, ) - assert file_obs.high_threshold == 9 - assert file_obs.med_threshold == 6 - assert file_obs.low_threshold == 3 + assert file_obs.high_file_access_threshold == 9 + assert file_obs.med_file_access_threshold == 6 + assert file_obs.low_file_access_threshold == 3 with pytest.raises(Exception): # should throw an error diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 9b2baf25..d01d0c8e 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -118,9 +118,9 @@ def test_nic_categories(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) - assert nic_obs.high_threshold == 10 # default - assert nic_obs.med_threshold == 5 # default - assert nic_obs.low_threshold == 0 # default + assert nic_obs.high_nmne_threshold == 10 # default + assert nic_obs.med_nmne_threshold == 5 # default + assert nic_obs.low_nmne_threshold == 0 # default def test_config_nic_categories(simulation): @@ -131,9 +131,9 @@ def test_config_nic_categories(simulation): include_nmne=True, ) - assert nic_obs.high_threshold == 9 - assert nic_obs.med_threshold == 6 - assert nic_obs.low_threshold == 3 + assert nic_obs.high_nmne_threshold == 9 + assert nic_obs.med_nmne_threshold == 6 + assert nic_obs.low_nmne_threshold == 3 with pytest.raises(Exception): # should throw an error diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 22374718..a0637969 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -84,9 +84,9 @@ def test_application_executions_categories(simulation): thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, ) - assert app_obs.high_threshold == 9 - assert app_obs.med_threshold == 6 - assert app_obs.low_threshold == 3 + assert app_obs.high_app_execution_threshold == 9 + assert app_obs.med_app_execution_threshold == 6 + assert app_obs.low_app_execution_threshold == 3 with pytest.raises(Exception): # should throw an error