#2445: apply PR suggestions

This commit is contained in:
Czar Echavez
2024-09-25 10:50:26 +01:00
parent 171dd83f2f
commit b9df2bd6a8
9 changed files with 103 additions and 81 deletions

View File

@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Log observation space data by episode and step.
- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted.
- New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only)
- Added ability to be able to set the observation threshold for NMNE, file access and application executions
### Changed
- ACL's are no longer applied to layer-2 traffic.

View File

@@ -55,21 +55,35 @@ class FileObservation(AbstractObservation, identifier="FILE"):
self.default_observation["num_access"] = 0
if thresholds.get("file_access") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
self.low_file_access_threshold = 0
self.med_file_access_threshold = 5
self.high_file_access_threshold = 10
else:
if self._validate_thresholds(
self._set_file_access_threshold(
thresholds=[
thresholds.get("file_access")["low"],
thresholds.get("file_access")["medium"],
thresholds.get("file_access")["high"],
],
threshold_identifier="file_access",
):
self.low_threshold = thresholds.get("file_access")["low"]
self.med_threshold = thresholds.get("file_access")["medium"]
self.high_threshold = thresholds.get("file_access")["high"]
]
)
def _set_file_access_threshold(self, thresholds: List[int]):
"""
Method that validates and then sets the file access threshold.
:param: thresholds: The file access threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=[
thresholds[0],
thresholds[1],
thresholds[2],
],
threshold_identifier="file_access",
):
self.low_file_access_threshold = thresholds[0]
self.med_file_access_threshold = thresholds[1]
self.high_file_access_threshold = thresholds[2]
def _categorise_num_access(self, num_access: int) -> int:
"""
@@ -78,11 +92,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
if num_access > self.high_file_access_threshold:
return 3
elif num_access > self.med_threshold:
elif num_access > self.med_file_access_threshold:
return 2
elif num_access > self.low_threshold:
elif num_access > self.low_file_access_threshold:
return 1
return 0
@@ -190,23 +204,6 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
self.file_system_requires_scan: bool = file_system_requires_scan
if thresholds.get("file_access") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
else:
if self._validate_thresholds(
thresholds=[
thresholds.get("file_access")["low"],
thresholds.get("file_access")["medium"],
thresholds.get("file_access")["high"],
],
threshold_identifier="file_access",
):
self.low_threshold = thresholds.get("file_access")["low"]
self.med_threshold = thresholds.get("file_access")["medium"]
self.high_threshold = thresholds.get("file_access")["high"]
self.files: List[FileObservation] = files
while len(self.files) < num_files:
self.files.append(

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import ClassVar, Dict, Optional
from typing import ClassVar, Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -55,21 +55,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
self.nmne_outbound_last_step: int = 0
if thresholds.get("nmne") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
self.low_nmne_threshold = 0
self.med_nmne_threshold = 5
self.high_nmne_threshold = 10
else:
if self._validate_thresholds(
self._set_nmne_threshold(
thresholds=[
thresholds.get("nmne")["low"],
thresholds.get("nmne")["medium"],
thresholds.get("nmne")["high"],
],
threshold_identifier="nmne",
):
self.low_threshold = thresholds.get("nmne")["low"]
self.med_threshold = thresholds.get("nmne")["medium"]
self.high_threshold = thresholds.get("nmne")["high"]
]
)
self.monitored_traffic = monitored_traffic
if self.monitored_traffic:
@@ -108,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_threshold:
if nmne_count > self.high_nmne_threshold:
return 3
elif nmne_count > self.med_threshold:
elif nmne_count > self.med_nmne_threshold:
return 2
elif nmne_count > self.low_threshold:
elif nmne_count > self.low_nmne_threshold:
return 1
return 0
@@ -126,6 +122,20 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
bandwidth_utilisation = traffic_value / nic_max_bandwidth
return int(bandwidth_utilisation * 9) + 1
def _set_nmne_threshold(self, thresholds: List[int]):
"""
Method that validates and then sets the NMNE threshold.
:param: thresholds: The NMNE threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=thresholds,
threshold_identifier="nmne",
):
self.low_nmne_threshold = thresholds[0]
self.med_nmne_threshold = thresholds[1]
self.high_nmne_threshold = thresholds[2]
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.

View File

@@ -97,7 +97,7 @@ class AbstractObservation(ABC):
if thresholds[idx] <= thresholds[idx - 1]:
raise Exception(
f"{threshold_identifier} threshold ({thresholds[idx]}) "
f"is greater than or equal to ({thresholds[idx - 1]}.)"
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
f"is greater than or equal to ({thresholds[idx]}.)"
)
return True

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict, Optional
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -109,21 +109,35 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
if thresholds.get("app_executions") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
self.low_app_execution_threshold = 0
self.med_app_execution_threshold = 5
self.high_app_execution_threshold = 10
else:
if self._validate_thresholds(
self._set_application_execution_thresholds(
thresholds=[
thresholds.get("app_executions")["low"],
thresholds.get("app_executions")["medium"],
thresholds.get("app_executions")["high"],
],
threshold_identifier="app_executions",
):
self.low_threshold = thresholds.get("app_executions")["low"]
self.med_threshold = thresholds.get("app_executions")["medium"]
self.high_threshold = thresholds.get("app_executions")["high"]
]
)
def _set_application_execution_thresholds(self, thresholds: List[int]):
"""
Method that validates and then sets the application execution threshold.
:param: thresholds: The application execution threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=[
thresholds[0],
thresholds[1],
thresholds[2],
],
threshold_identifier="app_executions",
):
self.low_app_execution_threshold = thresholds[0]
self.med_app_execution_threshold = thresholds[1]
self.high_app_execution_threshold = thresholds[2]
def _categorise_num_executions(self, num_executions: int) -> int:
"""
@@ -132,11 +146,11 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
:param num_access: Number of application executions.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
if num_executions > self.high_app_execution_threshold:
return 3
elif num_executions > self.med_threshold:
elif num_executions > self.med_app_execution_threshold:
return 2
elif num_executions > self.low_threshold:
elif num_executions > self.low_app_execution_threshold:
return 1
return 0

View File

@@ -34,9 +34,9 @@ def test_nmne_threshold():
# get NIC observation
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
assert nic_obs.low_threshold == 5
assert nic_obs.med_threshold == 25
assert nic_obs.high_threshold == 100
assert nic_obs.low_nmne_threshold == 5
assert nic_obs.med_nmne_threshold == 25
assert nic_obs.high_nmne_threshold == 100
def test_file_access_threshold():
@@ -47,9 +47,9 @@ def test_file_access_threshold():
# get file observation
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
assert file_obs.low_threshold == 2
assert file_obs.med_threshold == 5
assert file_obs.high_threshold == 10
assert file_obs.low_file_access_threshold == 2
assert file_obs.med_file_access_threshold == 5
assert file_obs.high_file_access_threshold == 10
def test_app_executions_threshold():
@@ -60,6 +60,6 @@ def test_app_executions_threshold():
# get application observation
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
assert app_obs.low_threshold == 2
assert app_obs.med_threshold == 3
assert app_obs.high_threshold == 5
assert app_obs.low_app_execution_threshold == 2
assert app_obs.med_app_execution_threshold == 3
assert app_obs.high_app_execution_threshold == 5

View File

@@ -53,9 +53,9 @@ def test_config_file_access_categories(simulation):
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
)
assert file_obs.high_threshold == 9
assert file_obs.med_threshold == 6
assert file_obs.low_threshold == 3
assert file_obs.high_file_access_threshold == 9
assert file_obs.med_file_access_threshold == 6
assert file_obs.low_file_access_threshold == 3
with pytest.raises(Exception):
# should throw an error

View File

@@ -118,9 +118,9 @@ def test_nic_categories(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.high_threshold == 10 # default
assert nic_obs.med_threshold == 5 # default
assert nic_obs.low_threshold == 0 # default
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
def test_config_nic_categories(simulation):
@@ -131,9 +131,9 @@ def test_config_nic_categories(simulation):
include_nmne=True,
)
assert nic_obs.high_threshold == 9
assert nic_obs.med_threshold == 6
assert nic_obs.low_threshold == 3
assert nic_obs.high_nmne_threshold == 9
assert nic_obs.med_nmne_threshold == 6
assert nic_obs.low_nmne_threshold == 3
with pytest.raises(Exception):
# should throw an error

View File

@@ -84,9 +84,9 @@ def test_application_executions_categories(simulation):
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
)
assert app_obs.high_threshold == 9
assert app_obs.med_threshold == 6
assert app_obs.low_threshold == 3
assert app_obs.high_app_execution_threshold == 9
assert app_obs.med_app_execution_threshold == 6
assert app_obs.low_app_execution_threshold == 3
with pytest.raises(Exception):
# should throw an error