Merged PR 548: #2445: added the ability to pass the game options thresholds to observations...
## Summary Added ability to pass the game options thresholds into observation classes This will allow for NICObservation, FileObservation (and FolderObservation) and ApplicationObservation to be able to get the thresholds for the training session. i.e. Allows for the thresholds for NMNE, file access and application executions to be configurable for training. ## Test process https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/548?_a=files&path=/tests/integration_tests/configuration_file_parsing/test_game_options_config.py ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code #2445: added the ability to pass the game options thresholds to observations so that relevant observation items can retrieve the thresholds from config Related work items: #2445
This commit is contained in:
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Log observation space data by episode and step.
|
||||
- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted.
|
||||
- New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only)
|
||||
- Added ability to set the observation threshold for NMNE, file access and application executions
|
||||
|
||||
### Changed
|
||||
- ACL's are no longer applied to layer-2 traffic.
|
||||
|
||||
@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a file observation instance.
|
||||
|
||||
@@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_file_access_threshold = 0
|
||||
self.med_file_access_threshold = 5
|
||||
self.high_file_access_threshold = 10
|
||||
else:
|
||||
self._set_file_access_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
def _set_file_access_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the file access threshold.
|
||||
|
||||
:param: thresholds: The file access threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_file_access_threshold = thresholds[0]
|
||||
self.med_file_access_threshold = thresholds[1]
|
||||
self.high_file_access_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
@@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
if num_access > self.high_file_access_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
elif num_access > self.med_file_access_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
elif num_access > self.low_file_access_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
where=parent_where + ["files", config.file_name],
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
@@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
num_files: int,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
@@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
where=None,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=self.file_system_requires_scan,
|
||||
thresholds=thresholds,
|
||||
)
|
||||
)
|
||||
while len(self.files) > num_files:
|
||||
@@ -248,6 +283,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
file_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
file_config.thresholds = config.thresholds
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(
|
||||
@@ -256,4 +292,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
num_files=config.num_files,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -169,7 +169,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
|
||||
self.nics.append(
|
||||
NICObservation(
|
||||
where=None,
|
||||
include_nmne=include_nmne,
|
||||
monitored_traffic=monitored_traffic,
|
||||
)
|
||||
)
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
@@ -279,12 +285,15 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
folder_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
folder_config.thresholds = config.thresholds
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
nic_config.thresholds = config.thresholds
|
||||
for service_config in config.services:
|
||||
service_config.services_requires_scan = config.services_requires_scan
|
||||
for application_config in config.applications:
|
||||
application_config.applications_requires_scan = config.applications_requires_scan
|
||||
application_config.thresholds = config.thresholds
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
@@ -295,7 +304,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
count = 1
|
||||
while len(nics) < config.num_nics:
|
||||
nic_config = NICObservation.ConfigSchema(
|
||||
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
|
||||
nic_num=count,
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
||||
count += 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Dict, Optional
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -28,7 +28,13 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
monitored_traffic: Optional[Dict] = None
|
||||
"""A dict containing which traffic types are to be included in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
monitored_traffic: Optional[Dict] = None,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
@@ -48,10 +54,18 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
if thresholds.get("nmne") is None:
|
||||
self.low_nmne_threshold = 0
|
||||
self.med_nmne_threshold = 5
|
||||
self.high_nmne_threshold = 10
|
||||
else:
|
||||
self._set_nmne_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("nmne")["low"],
|
||||
thresholds.get("nmne")["medium"],
|
||||
thresholds.get("nmne")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
self.monitored_traffic = monitored_traffic
|
||||
if self.monitored_traffic:
|
||||
@@ -108,6 +122,20 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
bandwidth_utilisation = traffic_value / nic_max_bandwidth
|
||||
return int(bandwidth_utilisation * 9) + 1
|
||||
|
||||
def _set_nmne_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the NMNE threshold.
|
||||
|
||||
:param: thresholds: The NMNE threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=thresholds,
|
||||
threshold_identifier="nmne",
|
||||
):
|
||||
self.low_nmne_threshold = thresholds[0]
|
||||
self.med_nmne_threshold = thresholds[1]
|
||||
self.high_nmne_threshold = thresholds[2]
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -228,6 +256,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
where=parent_where + ["NICs", config.nic_num],
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -205,6 +205,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
host_config.applications_requires_scan = config.applications_requires_scan
|
||||
if host_config.include_users is None:
|
||||
host_config.include_users = config.include_users
|
||||
if host_config.thresholds is None:
|
||||
host_config.thresholds = config.thresholds
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
@@ -221,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
router_config.num_rules = config.num_rules
|
||||
if router_config.include_users is None:
|
||||
router_config.include_users = config.include_users
|
||||
if router_config.thresholds is None:
|
||||
router_config.thresholds = config.thresholds
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
@@ -235,6 +239,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
firewall_config.num_rules = config.num_rules
|
||||
if firewall_config.include_users is None:
|
||||
firewall_config.include_users = config.include_users
|
||||
if firewall_config.thresholds is None:
|
||||
firewall_config.thresholds = config.thresholds
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
|
||||
@@ -113,7 +113,9 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
|
||||
obs_instance = obs_class.from_config(
|
||||
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
|
||||
)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
@@ -176,7 +178,7 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
|
||||
def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager":
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
@@ -187,11 +189,15 @@ class ObservationManager:
|
||||
AbstractObservation
|
||||
options: this must adhere to the chosen observation type's ConfigSchema nested class.
|
||||
:type config: Dict
|
||||
:param thresholds: Dictionary containing the observation thresholds.
|
||||
:type thresholds: Optional[Dict]
|
||||
"""
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
|
||||
observation = obs_class.from_config(
|
||||
config=obs_class.ConfigSchema(**config["options"], thresholds=thresholds),
|
||||
)
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
thresholds: Optional[Dict] = None
|
||||
"""A dict containing the observation thresholds."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
@@ -67,3 +70,34 @@ class AbstractObservation(ABC):
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
|
||||
"""
|
||||
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
|
||||
|
||||
Pass in the thresholds from low to high e.g.
|
||||
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
|
||||
|
||||
Throws an error if the threshold is not valid
|
||||
|
||||
:param: thresholds: List of thresholds in ascending order.
|
||||
:type: List[int]
|
||||
:param: threshold_identifier: The name of the threshold option.
|
||||
:type: Optional[str]
|
||||
|
||||
:returns: bool
|
||||
"""
|
||||
if thresholds is None or len(thresholds) < 2:
|
||||
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
|
||||
for idx in range(1, len(thresholds)):
|
||||
if not isinstance(thresholds[idx], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
if not isinstance(thresholds[idx - 1], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
|
||||
if thresholds[idx] <= thresholds[idx - 1]:
|
||||
raise Exception(
|
||||
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
|
||||
f"is greater than or equal to ({thresholds[idx]}.)"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -95,7 +95,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
If True, applications must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
|
||||
def __init__(self, where: WhereType, applications_requires_scan: bool) -> None:
|
||||
def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None:
|
||||
"""
|
||||
Initialise an application observation instance.
|
||||
|
||||
@@ -108,23 +108,49 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
self.applications_requires_scan = applications_requires_scan
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("app_executions") is None:
|
||||
self.low_app_execution_threshold = 0
|
||||
self.med_app_execution_threshold = 5
|
||||
self.high_app_execution_threshold = 10
|
||||
else:
|
||||
self._set_application_execution_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("app_executions")["low"],
|
||||
thresholds.get("app_executions")["medium"],
|
||||
thresholds.get("app_executions")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
def _set_application_execution_thresholds(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the application execution threshold.
|
||||
|
||||
:param: thresholds: The application execution threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="app_executions",
|
||||
):
|
||||
self.low_app_execution_threshold = thresholds[0]
|
||||
self.med_app_execution_threshold = thresholds[1]
|
||||
self.high_app_execution_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
Represent number of application executions as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:param num_access: Number of application executions.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
if num_executions > self.high_app_execution_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
elif num_executions > self.med_app_execution_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
elif num_executions > self.low_app_execution_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -180,4 +206,5 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
return cls(
|
||||
where=parent_where + ["applications", config.application_name],
|
||||
applications_requires_scan=config.applications_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -535,7 +535,7 @@ class PrimaiteGame:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg)
|
||||
obs_space = ObservationManager.from_config(config=observation_space_cfg, thresholds=game.options.thresholds)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
@@ -25,7 +25,19 @@ game:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 100
|
||||
medium: 25
|
||||
low: 5
|
||||
file_access:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 2
|
||||
app_executions:
|
||||
high: 5
|
||||
medium: 3
|
||||
low: 2
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
@@ -79,10 +91,16 @@ agents:
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
folders:
|
||||
- folder_name: root
|
||||
files:
|
||||
- file_name: "test.txt"
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_applications: 1
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
@@ -219,6 +237,9 @@ simulation:
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
- type: NTPServer
|
||||
file_system:
|
||||
- root:
|
||||
- "test.txt"
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
@@ -24,3 +24,42 @@ def test_thresholds():
|
||||
game = load_config(data_manipulation_config_path())
|
||||
|
||||
assert game.options.thresholds is not None
|
||||
|
||||
|
||||
def test_nmne_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["nmne"] is not None
|
||||
|
||||
# get NIC observation
|
||||
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
|
||||
assert nic_obs.low_nmne_threshold == 5
|
||||
assert nic_obs.med_nmne_threshold == 25
|
||||
assert nic_obs.high_nmne_threshold == 100
|
||||
|
||||
|
||||
def test_file_access_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["file_access"] is not None
|
||||
|
||||
# get file observation
|
||||
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
|
||||
assert file_obs.low_file_access_threshold == 2
|
||||
assert file_obs.med_file_access_threshold == 5
|
||||
assert file_obs.high_file_access_threshold == 10
|
||||
|
||||
|
||||
def test_app_executions_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["app_executions"] is not None
|
||||
|
||||
# get application observation
|
||||
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
|
||||
assert app_obs.low_app_execution_threshold == 2
|
||||
assert app_obs.med_app_execution_threshold == 3
|
||||
assert app_obs.high_app_execution_threshold == 5
|
||||
|
||||
@@ -44,6 +44,38 @@ def test_file_observation(simulation):
|
||||
assert observation_state.get("health_status") == 3 # corrupted
|
||||
|
||||
|
||||
def test_config_file_access_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert file_obs.high_file_access_threshold == 9
|
||||
assert file_obs.med_file_access_threshold == 6
|
||||
assert file_obs.low_file_access_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}},
|
||||
)
|
||||
|
||||
|
||||
def test_folder_observation(simulation):
|
||||
"""Test the folder observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
@@ -123,14 +123,11 @@ def test_nic_categories(simulation):
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature not implemented yet")
|
||||
def test_config_nic_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
@@ -142,9 +139,7 @@ def test_config_nic_categories(simulation):
|
||||
# should throw an error
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
@@ -152,9 +147,7 @@ def test_config_nic_categories(simulation):
|
||||
# should throw an error
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
high_nmne_threshold=9,
|
||||
thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,3 +73,33 @@ def test_application_observation(simulation):
|
||||
assert observation_state.get("health_status") == 1
|
||||
assert observation_state.get("operating_status") == 1 # running
|
||||
assert observation_state.get("num_executions") == 1
|
||||
|
||||
|
||||
def test_application_executions_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
app_obs = ApplicationObservation(
|
||||
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
|
||||
applications_requires_scan=False,
|
||||
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert app_obs.high_app_execution_threshold == 9
|
||||
assert app_obs.med_app_execution_threshold == 6
|
||||
assert app_obs.low_app_execution_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
ApplicationObservation(
|
||||
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
|
||||
applications_requires_scan=False,
|
||||
thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
ApplicationObservation(
|
||||
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
|
||||
applications_requires_scan=False,
|
||||
thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user