#2445: added the ability to pass the game options thresholds to observations so that relevant observation items can retrieve the thresholds from config
This commit is contained in:
@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a file observation instance.
|
||||
|
||||
@@ -48,10 +54,22 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_threshold = thresholds.get("file_access")["low"]
|
||||
self.med_threshold = thresholds.get("file_access")["medium"]
|
||||
self.high_threshold = thresholds.get("file_access")["high"]
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
@@ -122,6 +140,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
where=parent_where + ["files", config.file_name],
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
@@ -149,6 +168,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
num_files: int,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
@@ -170,6 +190,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
|
||||
self.file_system_requires_scan: bool = file_system_requires_scan
|
||||
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_threshold = thresholds.get("file_access")["low"]
|
||||
self.med_threshold = thresholds.get("file_access")["medium"]
|
||||
self.high_threshold = thresholds.get("file_access")["high"]
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(
|
||||
@@ -177,6 +214,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
where=None,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=self.file_system_requires_scan,
|
||||
thresholds=thresholds,
|
||||
)
|
||||
)
|
||||
while len(self.files) > num_files:
|
||||
@@ -248,6 +286,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
file_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
file_config.thresholds = config.thresholds
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(
|
||||
@@ -256,4 +295,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
num_files=config.num_files,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -151,7 +151,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
|
||||
self.nics.append(
|
||||
NICObservation(
|
||||
where=None,
|
||||
include_nmne=include_nmne,
|
||||
monitored_traffic=monitored_traffic,
|
||||
)
|
||||
)
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
@@ -257,12 +263,16 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
where = parent_where + [config.hostname]
|
||||
|
||||
# Pass down shared/common config items
|
||||
for app_config in config.applications:
|
||||
app_config.thresholds = config.thresholds
|
||||
for folder_config in config.folders:
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
folder_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
folder_config.thresholds = config.thresholds
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
nic_config.thresholds = config.thresholds
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
@@ -273,7 +283,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
count = 1
|
||||
while len(nics) < config.num_nics:
|
||||
nic_config = NICObservation.ConfigSchema(
|
||||
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
|
||||
nic_num=count,
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
||||
count += 1
|
||||
|
||||
@@ -24,7 +24,13 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
monitored_traffic: Optional[Dict] = None
|
||||
"""A dict containing which traffic types are to be included in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
monitored_traffic: Optional[Dict] = None,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
@@ -44,10 +50,22 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
if thresholds.get("nmne") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("nmne")["low"],
|
||||
thresholds.get("nmne")["medium"],
|
||||
thresholds.get("nmne")["high"],
|
||||
],
|
||||
threshold_identifier="nmne",
|
||||
):
|
||||
self.low_threshold = thresholds.get("nmne")["low"]
|
||||
self.med_threshold = thresholds.get("nmne")["medium"]
|
||||
self.high_threshold = thresholds.get("nmne")["high"]
|
||||
|
||||
self.monitored_traffic = monitored_traffic
|
||||
if self.monitored_traffic:
|
||||
@@ -86,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > self.high_nmne_threshold:
|
||||
if nmne_count > self.high_threshold:
|
||||
return 3
|
||||
elif nmne_count > self.med_nmne_threshold:
|
||||
elif nmne_count > self.med_threshold:
|
||||
return 2
|
||||
elif nmne_count > self.low_nmne_threshold:
|
||||
elif nmne_count > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -224,6 +242,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
where=parent_where + ["NICs", config.nic_num],
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -195,6 +195,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
host_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
if host_config.include_users is None:
|
||||
host_config.include_users = config.include_users
|
||||
if host_config.thresholds is None:
|
||||
host_config.thresholds = config.thresholds
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
@@ -211,6 +213,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
router_config.num_rules = config.num_rules
|
||||
if router_config.include_users is None:
|
||||
router_config.include_users = config.include_users
|
||||
if router_config.thresholds is None:
|
||||
router_config.thresholds = config.thresholds
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
@@ -225,6 +229,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
firewall_config.num_rules = config.num_rules
|
||||
if firewall_config.include_users is None:
|
||||
firewall_config.include_users = config.include_users
|
||||
if firewall_config.thresholds is None:
|
||||
firewall_config.thresholds = config.thresholds
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
|
||||
@@ -113,7 +113,9 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
|
||||
obs_instance = obs_class.from_config(
|
||||
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
|
||||
)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
@@ -176,7 +178,7 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
|
||||
def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager":
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
@@ -187,11 +189,15 @@ class ObservationManager:
|
||||
AbstractObservation
|
||||
options: this must adhere to the chosen observation type's ConfigSchema nested class.
|
||||
:type config: Dict
|
||||
:param thresholds: Dictionary containing the observation thresholds.
|
||||
:type thresholds: Optional[Dict]
|
||||
"""
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
|
||||
observation = obs_class.from_config(
|
||||
config=obs_class.ConfigSchema(**config["options"], thresholds=thresholds),
|
||||
)
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
thresholds: Optional[Dict] = None
|
||||
"""A dict containing the observation thresholds."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
@@ -67,3 +70,34 @@ class AbstractObservation(ABC):
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
|
||||
"""
|
||||
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
|
||||
|
||||
Pass in the thresholds from low to high e.g.
|
||||
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
|
||||
|
||||
Throws an error if the threshold is not valid
|
||||
|
||||
:param: thresholds: List of thresholds in ascending order.
|
||||
:type: List[int]
|
||||
:param: threshold_identifier: The name of the threshold option.
|
||||
:type: Optional[str]
|
||||
|
||||
:returns: bool
|
||||
"""
|
||||
if thresholds is None or len(thresholds) < 2:
|
||||
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
|
||||
for idx in range(1, len(thresholds)):
|
||||
if not isinstance(thresholds[idx], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
if not isinstance(thresholds[idx - 1], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
|
||||
if thresholds[idx] <= thresholds[idx - 1]:
|
||||
raise Exception(
|
||||
f"{threshold_identifier} threshold ({thresholds[idx]}) "
|
||||
f"is greater than or equal to ({thresholds[idx - 1]}.)"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
def __init__(self, where: WhereType, thresholds: Optional[Dict] = {}) -> None:
|
||||
"""
|
||||
Initialise an application observation instance.
|
||||
|
||||
@@ -94,16 +94,28 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("app_executions") is None:
|
||||
self.low_threshold = 0
|
||||
self.med_threshold = 5
|
||||
self.high_threshold = 10
|
||||
else:
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("app_executions")["low"],
|
||||
thresholds.get("app_executions")["medium"],
|
||||
thresholds.get("app_executions")["high"],
|
||||
],
|
||||
threshold_identifier="app_executions",
|
||||
):
|
||||
self.low_threshold = thresholds.get("app_executions")["low"]
|
||||
self.med_threshold = thresholds.get("app_executions")["medium"]
|
||||
self.high_threshold = thresholds.get("app_executions")["high"]
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
Represent number of application executions as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:param num_access: Number of application executions.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
@@ -161,4 +173,4 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
return cls(where=parent_where + ["applications", config.application_name], thresholds=config.thresholds)
|
||||
|
||||
@@ -531,7 +531,7 @@ class PrimaiteGame:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg)
|
||||
obs_space = ObservationManager.from_config(config=observation_space_cfg, thresholds=game.options.thresholds)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
@@ -25,7 +25,19 @@ game:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 100
|
||||
medium: 25
|
||||
low: 5
|
||||
file_access:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 2
|
||||
app_executions:
|
||||
high: 5
|
||||
medium: 3
|
||||
low: 2
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
@@ -79,10 +91,16 @@ agents:
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
folders:
|
||||
- folder_name: root
|
||||
files:
|
||||
- file_name: "test.txt"
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_applications: 1
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
@@ -219,6 +237,9 @@ simulation:
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
- type: NTPServer
|
||||
file_system:
|
||||
- root:
|
||||
- "test.txt"
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
@@ -24,3 +24,42 @@ def test_thresholds():
|
||||
game = load_config(data_manipulation_config_path())
|
||||
|
||||
assert game.options.thresholds is not None
|
||||
|
||||
|
||||
def test_nmne_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["nmne"] is not None
|
||||
|
||||
# get NIC observation
|
||||
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
|
||||
assert nic_obs.low_threshold == 5
|
||||
assert nic_obs.med_threshold == 25
|
||||
assert nic_obs.high_threshold == 100
|
||||
|
||||
|
||||
def test_file_access_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["file_access"] is not None
|
||||
|
||||
# get file observation
|
||||
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
|
||||
assert file_obs.low_threshold == 2
|
||||
assert file_obs.med_threshold == 5
|
||||
assert file_obs.high_threshold == 10
|
||||
|
||||
|
||||
def test_app_executions_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["app_executions"] is not None
|
||||
|
||||
# get application observation
|
||||
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
|
||||
assert app_obs.low_threshold == 2
|
||||
assert app_obs.med_threshold == 3
|
||||
assert app_obs.high_threshold == 5
|
||||
|
||||
@@ -44,6 +44,38 @@ def test_file_observation(simulation):
|
||||
assert observation_state.get("health_status") == 3 # corrupted
|
||||
|
||||
|
||||
def test_config_file_access_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert file_obs.high_threshold == 9
|
||||
assert file_obs.med_threshold == 6
|
||||
assert file_obs.low_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}},
|
||||
)
|
||||
|
||||
|
||||
def test_folder_observation(simulation):
|
||||
"""Test the folder observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
@@ -110,33 +110,28 @@ def test_nic_categories(simulation):
|
||||
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 10 # default
|
||||
assert nic_obs.med_nmne_threshold == 5 # default
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
assert nic_obs.high_threshold == 10 # default
|
||||
assert nic_obs.med_threshold == 5 # default
|
||||
assert nic_obs.low_threshold == 0 # default
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature not implemented yet")
|
||||
def test_config_nic_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 9
|
||||
assert nic_obs.med_nmne_threshold == 6
|
||||
assert nic_obs.low_nmne_threshold == 3
|
||||
assert nic_obs.high_threshold == 9
|
||||
assert nic_obs.med_threshold == 6
|
||||
assert nic_obs.low_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
@@ -144,9 +139,7 @@ def test_config_nic_categories(simulation):
|
||||
# should throw an error
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
high_nmne_threshold=9,
|
||||
thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -69,3 +69,30 @@ def test_application_observation(simulation):
|
||||
assert observation_state.get("health_status") == 1
|
||||
assert observation_state.get("operating_status") == 1 # running
|
||||
assert observation_state.get("num_executions") == 1
|
||||
|
||||
|
||||
def test_application_executions_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
app_obs = ApplicationObservation(
|
||||
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
|
||||
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert app_obs.high_threshold == 9
|
||||
assert app_obs.med_threshold == 6
|
||||
assert app_obs.low_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
ApplicationObservation(
|
||||
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
|
||||
thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
ApplicationObservation(
|
||||
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
|
||||
thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user