#2445: apply PR suggestions
This commit is contained in:
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Log observation space data by episode and step.
|
||||
- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted.
|
||||
- New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only)
|
||||
- Added ability to be able to set the observation threshold for NMNE, file access and application executions
|
||||
|
||||
### Changed
|
||||
- ACL's are no longer applied to layer-2 traffic.
|
||||
|
||||
@@ -55,21 +55,35 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
self.low_file_access_threshold = 0
|
||||
self.med_file_access_threshold = 5
|
||||
self.high_file_access_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
self._set_file_access_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_threshold = thresholds.get("file_access")["low"]
|
||||
self.med_threshold = thresholds.get("file_access")["medium"]
|
||||
self.high_threshold = thresholds.get("file_access")["high"]
|
||||
]
|
||||
)
|
||||
|
||||
def _set_file_access_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the file access threshold.
|
||||
|
||||
:param: thresholds: The file access threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_file_access_threshold = thresholds[0]
|
||||
self.med_file_access_threshold = thresholds[1]
|
||||
self.high_file_access_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
@@ -78,11 +92,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
if num_access > self.high_file_access_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
elif num_access > self.med_file_access_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
elif num_access > self.low_file_access_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -190,23 +204,6 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
|
||||
self.file_system_requires_scan: bool = file_system_requires_scan
|
||||
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_threshold = thresholds.get("file_access")["low"]
|
||||
self.med_threshold = thresholds.get("file_access")["medium"]
|
||||
self.high_threshold = thresholds.get("file_access")["high"]
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Dict, Optional
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -55,21 +55,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
|
||||
if thresholds.get("nmne") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
self.low_nmne_threshold = 0
|
||||
self.med_nmne_threshold = 5
|
||||
self.high_nmne_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
self._set_nmne_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("nmne")["low"],
|
||||
thresholds.get("nmne")["medium"],
|
||||
thresholds.get("nmne")["high"],
|
||||
],
|
||||
threshold_identifier="nmne",
|
||||
):
|
||||
self.low_threshold = thresholds.get("nmne")["low"]
|
||||
self.med_threshold = thresholds.get("nmne")["medium"]
|
||||
self.high_threshold = thresholds.get("nmne")["high"]
|
||||
]
|
||||
)
|
||||
|
||||
self.monitored_traffic = monitored_traffic
|
||||
if self.monitored_traffic:
|
||||
@@ -108,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > self.high_threshold:
|
||||
if nmne_count > self.high_nmne_threshold:
|
||||
return 3
|
||||
elif nmne_count > self.med_threshold:
|
||||
elif nmne_count > self.med_nmne_threshold:
|
||||
return 2
|
||||
elif nmne_count > self.low_threshold:
|
||||
elif nmne_count > self.low_nmne_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -126,6 +122,20 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
bandwidth_utilisation = traffic_value / nic_max_bandwidth
|
||||
return int(bandwidth_utilisation * 9) + 1
|
||||
|
||||
def _set_nmne_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the NMNE threshold.
|
||||
|
||||
:param: thresholds: The NMNE threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=thresholds,
|
||||
threshold_identifier="nmne",
|
||||
):
|
||||
self.low_nmne_threshold = thresholds[0]
|
||||
self.med_nmne_threshold = thresholds[1]
|
||||
self.high_nmne_threshold = thresholds[2]
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
@@ -97,7 +97,7 @@ class AbstractObservation(ABC):
|
||||
|
||||
if thresholds[idx] <= thresholds[idx - 1]:
|
||||
raise Exception(
|
||||
f"{threshold_identifier} threshold ({thresholds[idx]}) "
|
||||
f"is greater than or equal to ({thresholds[idx - 1]}.)"
|
||||
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
|
||||
f"is greater than or equal to ({thresholds[idx]}.)"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -109,21 +109,35 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
if thresholds.get("app_executions") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
self.low_app_execution_threshold = 0
|
||||
self.med_app_execution_threshold = 5
|
||||
self.high_app_execution_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
self._set_application_execution_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("app_executions")["low"],
|
||||
thresholds.get("app_executions")["medium"],
|
||||
thresholds.get("app_executions")["high"],
|
||||
],
|
||||
threshold_identifier="app_executions",
|
||||
):
|
||||
self.low_threshold = thresholds.get("app_executions")["low"]
|
||||
self.med_threshold = thresholds.get("app_executions")["medium"]
|
||||
self.high_threshold = thresholds.get("app_executions")["high"]
|
||||
]
|
||||
)
|
||||
|
||||
def _set_application_execution_thresholds(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the application execution threshold.
|
||||
|
||||
:param: thresholds: The application execution threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="app_executions",
|
||||
):
|
||||
self.low_app_execution_threshold = thresholds[0]
|
||||
self.med_app_execution_threshold = thresholds[1]
|
||||
self.high_app_execution_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
@@ -132,11 +146,11 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
:param num_access: Number of application executions.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
if num_executions > self.high_app_execution_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
elif num_executions > self.med_app_execution_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
elif num_executions > self.low_app_execution_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
@@ -34,9 +34,9 @@ def test_nmne_threshold():
|
||||
|
||||
# get NIC observation
|
||||
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
|
||||
assert nic_obs.low_threshold == 5
|
||||
assert nic_obs.med_threshold == 25
|
||||
assert nic_obs.high_threshold == 100
|
||||
assert nic_obs.low_nmne_threshold == 5
|
||||
assert nic_obs.med_nmne_threshold == 25
|
||||
assert nic_obs.high_nmne_threshold == 100
|
||||
|
||||
|
||||
def test_file_access_threshold():
|
||||
@@ -47,9 +47,9 @@ def test_file_access_threshold():
|
||||
|
||||
# get file observation
|
||||
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
|
||||
assert file_obs.low_threshold == 2
|
||||
assert file_obs.med_threshold == 5
|
||||
assert file_obs.high_threshold == 10
|
||||
assert file_obs.low_file_access_threshold == 2
|
||||
assert file_obs.med_file_access_threshold == 5
|
||||
assert file_obs.high_file_access_threshold == 10
|
||||
|
||||
|
||||
def test_app_executions_threshold():
|
||||
@@ -60,6 +60,6 @@ def test_app_executions_threshold():
|
||||
|
||||
# get application observation
|
||||
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
|
||||
assert app_obs.low_threshold == 2
|
||||
assert app_obs.med_threshold == 3
|
||||
assert app_obs.high_threshold == 5
|
||||
assert app_obs.low_app_execution_threshold == 2
|
||||
assert app_obs.med_app_execution_threshold == 3
|
||||
assert app_obs.high_app_execution_threshold == 5
|
||||
|
||||
@@ -53,9 +53,9 @@ def test_config_file_access_categories(simulation):
|
||||
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert file_obs.high_threshold == 9
|
||||
assert file_obs.med_threshold == 6
|
||||
assert file_obs.low_threshold == 3
|
||||
assert file_obs.high_file_access_threshold == 9
|
||||
assert file_obs.med_file_access_threshold == 6
|
||||
assert file_obs.low_file_access_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
|
||||
@@ -118,9 +118,9 @@ def test_nic_categories(simulation):
|
||||
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
assert nic_obs.high_threshold == 10 # default
|
||||
assert nic_obs.med_threshold == 5 # default
|
||||
assert nic_obs.low_threshold == 0 # default
|
||||
assert nic_obs.high_nmne_threshold == 10 # default
|
||||
assert nic_obs.med_nmne_threshold == 5 # default
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
|
||||
def test_config_nic_categories(simulation):
|
||||
@@ -131,9 +131,9 @@ def test_config_nic_categories(simulation):
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
assert nic_obs.high_threshold == 9
|
||||
assert nic_obs.med_threshold == 6
|
||||
assert nic_obs.low_threshold == 3
|
||||
assert nic_obs.high_nmne_threshold == 9
|
||||
assert nic_obs.med_nmne_threshold == 6
|
||||
assert nic_obs.low_nmne_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
|
||||
@@ -84,9 +84,9 @@ def test_application_executions_categories(simulation):
|
||||
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert app_obs.high_threshold == 9
|
||||
assert app_obs.med_threshold == 6
|
||||
assert app_obs.low_threshold == 3
|
||||
assert app_obs.high_app_execution_threshold == 9
|
||||
assert app_obs.med_app_execution_threshold == 6
|
||||
assert app_obs.low_app_execution_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
|
||||
Reference in New Issue
Block a user