#2445: added the ability to pass the game options thresholds to observations so that relevant observation items can retrieve the thresholds from config

This commit is contained in:
Czar Echavez
2024-09-17 12:19:35 +01:00
parent eb24d1270b
commit 4391d7cdd5
13 changed files with 290 additions and 48 deletions

View File

@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
file_system_requires_scan: Optional[bool] = None
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
def __init__(
self,
where: WhereType,
include_num_access: bool,
file_system_requires_scan: bool,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a file observation instance.
@@ -48,10 +54,22 @@ class FileObservation(AbstractObservation, identifier="FILE"):
if self.include_num_access:
self.default_observation["num_access"] = 0
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
if thresholds.get("file_access") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
else:
if self._validate_thresholds(
thresholds=[
thresholds.get("file_access")["low"],
thresholds.get("file_access")["medium"],
thresholds.get("file_access")["high"],
],
threshold_identifier="file_access",
):
self.low_threshold = thresholds.get("file_access")["low"]
self.med_threshold = thresholds.get("file_access")["medium"]
self.high_threshold = thresholds.get("file_access")["high"]
def _categorise_num_access(self, num_access: int) -> int:
"""
@@ -122,6 +140,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
where=parent_where + ["files", config.file_name],
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
thresholds=config.thresholds,
)
@@ -149,6 +168,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
num_files: int,
include_num_access: bool,
file_system_requires_scan: bool,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a folder observation instance.
@@ -170,6 +190,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
self.file_system_requires_scan: bool = file_system_requires_scan
if thresholds.get("file_access") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
else:
if self._validate_thresholds(
thresholds=[
thresholds.get("file_access")["low"],
thresholds.get("file_access")["medium"],
thresholds.get("file_access")["high"],
],
threshold_identifier="file_access",
):
self.low_threshold = thresholds.get("file_access")["low"]
self.med_threshold = thresholds.get("file_access")["medium"]
self.high_threshold = thresholds.get("file_access")["high"]
self.files: List[FileObservation] = files
while len(self.files) < num_files:
self.files.append(
@@ -177,6 +214,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
where=None,
include_num_access=include_num_access,
file_system_requires_scan=self.file_system_requires_scan,
thresholds=thresholds,
)
)
while len(self.files) > num_files:
@@ -248,6 +286,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
for file_config in config.files:
file_config.include_num_access = config.include_num_access
file_config.file_system_requires_scan = config.file_system_requires_scan
file_config.thresholds = config.thresholds
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(
@@ -256,4 +295,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
num_files=config.num_files,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
thresholds=config.thresholds,
)

View File

@@ -151,7 +151,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
self.nics.append(
NICObservation(
where=None,
include_nmne=include_nmne,
monitored_traffic=monitored_traffic,
)
)
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
@@ -257,12 +263,16 @@ class HostObservation(AbstractObservation, identifier="HOST"):
where = parent_where + [config.hostname]
# Pass down shared/common config items
for app_config in config.applications:
app_config.thresholds = config.thresholds
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
folder_config.file_system_requires_scan = config.file_system_requires_scan
folder_config.thresholds = config.thresholds
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
nic_config.thresholds = config.thresholds
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
@@ -273,7 +283,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
count = 1
while len(nics) < config.num_nics:
nic_config = NICObservation.ConfigSchema(
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
nic_num=count,
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
thresholds=config.thresholds,
)
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
count += 1

View File

@@ -24,7 +24,13 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
monitored_traffic: Optional[Dict] = None
"""A dict containing which traffic types are to be included in the observation."""
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
def __init__(
self,
where: WhereType,
include_nmne: bool,
monitored_traffic: Optional[Dict] = None,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a network interface observation instance.
@@ -44,10 +50,22 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
self.nmne_inbound_last_step: int = 0
self.nmne_outbound_last_step: int = 0
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
if thresholds.get("nmne") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
else:
if self._validate_thresholds(
thresholds=[
thresholds.get("nmne")["low"],
thresholds.get("nmne")["medium"],
thresholds.get("nmne")["high"],
],
threshold_identifier="nmne",
):
self.low_threshold = thresholds.get("nmne")["low"]
self.med_threshold = thresholds.get("nmne")["medium"]
self.high_threshold = thresholds.get("nmne")["high"]
self.monitored_traffic = monitored_traffic
if self.monitored_traffic:
@@ -86,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_nmne_threshold:
if nmne_count > self.high_threshold:
return 3
elif nmne_count > self.med_nmne_threshold:
elif nmne_count > self.med_threshold:
return 2
elif nmne_count > self.low_nmne_threshold:
elif nmne_count > self.low_threshold:
return 1
return 0
@@ -224,6 +242,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
where=parent_where + ["NICs", config.nic_num],
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
thresholds=config.thresholds,
)

View File

@@ -195,6 +195,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
host_config.file_system_requires_scan = config.file_system_requires_scan
if host_config.include_users is None:
host_config.include_users = config.include_users
if host_config.thresholds is None:
host_config.thresholds = config.thresholds
for router_config in config.routers:
if router_config.num_ports is None:
@@ -211,6 +213,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
router_config.num_rules = config.num_rules
if router_config.include_users is None:
router_config.include_users = config.include_users
if router_config.thresholds is None:
router_config.thresholds = config.thresholds
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
@@ -225,6 +229,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
firewall_config.num_rules = config.num_rules
if firewall_config.include_users is None:
firewall_config.include_users = config.include_users
if firewall_config.thresholds is None:
firewall_config.thresholds = config.thresholds
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]

View File

@@ -113,7 +113,9 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
obs_instance = obs_class.from_config(
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
)
instances[component.label] = obs_instance
return cls(components=instances)
@@ -176,7 +178,7 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager":
"""
Create observation space from a config.
@@ -187,11 +189,15 @@ class ObservationManager:
AbstractObservation
options: this must adhere to the chosen observation type's ConfigSchema nested class.
:type config: Dict
:param thresholds: Dictionary containing the observation thresholds.
:type thresholds: Optional[Dict]
"""
if config is None:
return cls(NullObservation())
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
observation = obs_class.from_config(
config=obs_class.ConfigSchema(**config["options"], thresholds=thresholds),
)
obs_manager = cls(observation)
return obs_manager

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
class ConfigSchema(ABC, BaseModel):
"""Config schema for observations."""
thresholds: Optional[Dict] = None
"""A dict containing the observation thresholds."""
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
@@ -67,3 +70,34 @@ class AbstractObservation(ABC):
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
"""
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
Pass in the thresholds from low to high e.g.
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
Throws an error if the threshold is not valid
:param: thresholds: List of thresholds in ascending order.
:type: List[int]
:param: threshold_identifier: The name of the threshold option.
:type: Optional[str]
:returns: bool
"""
if thresholds is None or len(thresholds) < 2:
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
for idx in range(1, len(thresholds)):
if not isinstance(thresholds[idx], int):
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
if not isinstance(thresholds[idx - 1], int):
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
if thresholds[idx] <= thresholds[idx - 1]:
raise Exception(
f"{threshold_identifier} threshold ({thresholds[idx]}) "
f"is greater than or equal to ({thresholds[idx - 1]}.)"
)
return True

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict
from typing import Dict, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
application_name: str
"""Name of the application, used for querying simulation state dictionary"""
def __init__(self, where: WhereType) -> None:
def __init__(self, where: WhereType, thresholds: Optional[Dict] = {}) -> None:
"""
Initialise an application observation instance.
@@ -94,16 +94,28 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
if thresholds.get("app_executions") is None:
self.low_threshold = 0
self.med_threshold = 5
self.high_threshold = 10
else:
if self._validate_thresholds(
thresholds=[
thresholds.get("app_executions")["low"],
thresholds.get("app_executions")["medium"],
thresholds.get("app_executions")["high"],
],
threshold_identifier="app_executions",
):
self.low_threshold = thresholds.get("app_executions")["low"]
self.med_threshold = thresholds.get("app_executions")["medium"]
self.high_threshold = thresholds.get("app_executions")["high"]
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
Represent number of application executions as a categorical variable.
:param num_access: Number of file accesses.
:param num_access: Number of application executions.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
@@ -161,4 +173,4 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
:return: Constructed application observation instance.
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["applications", config.application_name])
return cls(where=parent_where + ["applications", config.application_name], thresholds=config.thresholds)

View File

@@ -531,7 +531,7 @@ class PrimaiteGame:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg)
obs_space = ObservationManager.from_config(config=observation_space_cfg, thresholds=game.options.thresholds)
# CREATE ACTION SPACE
action_space = ActionManager.from_config(game, action_space_cfg)

View File

@@ -25,7 +25,19 @@ game:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 100
medium: 25
low: 5
file_access:
high: 10
medium: 5
low: 2
app_executions:
high: 5
medium: 3
low: 2
agents:
- ref: client_2_green_user
team: GREEN
@@ -79,10 +91,16 @@ agents:
options:
hosts:
- hostname: client_1
applications:
- application_name: WebBrowser
folders:
- folder_name: root
files:
- file_name: "test.txt"
- hostname: client_2
- hostname: client_3
num_services: 1
num_applications: 0
num_applications: 1
num_folders: 1
num_files: 1
num_nics: 2
@@ -219,6 +237,9 @@ simulation:
options:
ntp_server_ip: 192.168.1.10
- type: NTPServer
file_system:
- root:
- "test.txt"
- hostname: client_2
type: computer
ip_address: 192.168.10.22

View File

@@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path
from primaite.game.game import PrimaiteGame
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
@@ -24,3 +24,42 @@ def test_thresholds():
game = load_config(data_manipulation_config_path())
assert game.options.thresholds is not None
def test_nmne_threshold():
"""Test that the NMNE thresholds are properly loaded in by observation."""
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
assert game.options.thresholds["nmne"] is not None
# get NIC observation
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
assert nic_obs.low_threshold == 5
assert nic_obs.med_threshold == 25
assert nic_obs.high_threshold == 100
def test_file_access_threshold():
"""Test that the NMNE thresholds are properly loaded in by observation."""
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
assert game.options.thresholds["file_access"] is not None
# get file observation
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
assert file_obs.low_threshold == 2
assert file_obs.med_threshold == 5
assert file_obs.high_threshold == 10
def test_app_executions_threshold():
"""Test that the NMNE thresholds are properly loaded in by observation."""
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
assert game.options.thresholds["app_executions"] is not None
# get application observation
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
assert app_obs.low_threshold == 2
assert app_obs.med_threshold == 3
assert app_obs.high_threshold == 5

View File

@@ -44,6 +44,38 @@ def test_file_observation(simulation):
assert observation_state.get("health_status") == 3 # corrupted
def test_config_file_access_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
)
assert file_obs.high_threshold == 9
assert file_obs.med_threshold == 6
assert file_obs.low_threshold == 3
with pytest.raises(Exception):
# should throw an error
FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}},
)
with pytest.raises(Exception):
# should throw an error
FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}},
)
def test_folder_observation(simulation):
"""Test the folder observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")

View File

@@ -110,33 +110,28 @@ def test_nic_categories(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
assert nic_obs.high_threshold == 10 # default
assert nic_obs.med_threshold == 5 # default
assert nic_obs.low_threshold == 0 # default
@pytest.mark.skip(reason="Feature not implemented yet")
def test_config_nic_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
high_nmne_threshold=9,
thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}},
include_nmne=True,
)
assert nic_obs.high_nmne_threshold == 9
assert nic_obs.med_nmne_threshold == 6
assert nic_obs.low_nmne_threshold == 3
assert nic_obs.high_threshold == 9
assert nic_obs.med_threshold == 6
assert nic_obs.low_threshold == 3
with pytest.raises(Exception):
# should throw an error
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
high_nmne_threshold=9,
thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}},
include_nmne=True,
)
@@ -144,9 +139,7 @@ def test_config_nic_categories(simulation):
# should throw an error
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,
high_nmne_threshold=9,
thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}},
include_nmne=True,
)

View File

@@ -69,3 +69,30 @@ def test_application_observation(simulation):
assert observation_state.get("health_status") == 1
assert observation_state.get("operating_status") == 1 # running
assert observation_state.get("num_executions") == 1
def test_application_executions_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
app_obs = ApplicationObservation(
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
)
assert app_obs.high_threshold == 9
assert app_obs.med_threshold == 6
assert app_obs.low_threshold == 3
with pytest.raises(Exception):
# should throw an error
ApplicationObservation(
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}},
)
with pytest.raises(Exception):
# should throw an error
ApplicationObservation(
where=["network", "nodes", pc.hostname, "applications", "WebBrowser"],
thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}},
)