#2350: configurable NMNE category thresholds
This commit is contained in:
@@ -30,6 +30,11 @@ game:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
|
||||
175
src/primaite/game/agent/observations/nic_observations.py
Normal file
175
src/primaite/game/agent/observations/nic_observations.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
|
||||
low_nmne_threshold: int = 0
|
||||
"""The minimum number of malicious network events to be considered low."""
|
||||
med_nmne_threshold: int = 5
|
||||
"""The minimum number of malicious network events to be considered medium."""
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"nmne": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
return data
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
low_nmne_threshold: Optional[int] = 0,
|
||||
med_nmne_threshold: Optional[int] = 5,
|
||||
high_nmne_threshold: Optional[int] = 10,
|
||||
) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
|
||||
def _validate_nmne_categories(
|
||||
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
|
||||
):
|
||||
"""
|
||||
Validates the nmne threshold config.
|
||||
|
||||
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
|
||||
|
||||
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
|
||||
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
|
||||
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
|
||||
"""
|
||||
if high_nmne_threshold <= med_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
|
||||
f"than medium nmne count ({med_nmne_threshold})"
|
||||
)
|
||||
|
||||
if med_nmne_threshold <= low_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
|
||||
f"than low nmne count ({low_nmne_threshold})"
|
||||
)
|
||||
|
||||
self.high_nmne_threshold = high_nmne_threshold
|
||||
self.med_nmne_threshold = med_nmne_threshold
|
||||
self.low_nmne_threshold = low_nmne_threshold
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||
|
||||
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
|
||||
|
||||
Bins are defined as follows:
|
||||
- 0: No MNEs detected (0 events).
|
||||
- 1: Low number of MNEs (default 1-5 events).
|
||||
- 2: Moderate number of MNEs (default 6-10 events).
|
||||
- 3: High number of MNEs (default more than 10 events).
|
||||
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > self.high_nmne_threshold:
|
||||
return 3
|
||||
elif nmne_count > self.med_nmne_threshold:
|
||||
return 2
|
||||
elif nmne_count > self.low_nmne_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if CAPTURE_NMNE:
|
||||
obs_dict.update({"nmne": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
|
||||
return obs_dict
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"nic_status": spaces.Discrete(3),
|
||||
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
"""
|
||||
low_nmne_threshold = None
|
||||
med_nmne_threshold = None
|
||||
high_nmne_threshold = None
|
||||
|
||||
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
|
||||
threshold = game.options.thresholds["nmne"]
|
||||
|
||||
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
|
||||
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
|
||||
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
|
||||
|
||||
return cls(
|
||||
where=parent_where + ["NICs", config["nic_num"]],
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
@@ -4,7 +4,8 @@ from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -116,107 +115,6 @@ class LinkObservation(AbstractObservation):
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"nmne": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
return data
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||
|
||||
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
|
||||
|
||||
Bins are defined as follows:
|
||||
- 0: No MNEs detected (0 events).
|
||||
- 1: Low number of MNEs (1-5 events).
|
||||
- 2: Moderate number of MNEs (6-10 events).
|
||||
- 3: High number of MNEs (more than 10 events).
|
||||
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > 10:
|
||||
return 3
|
||||
elif nmne_count > 5:
|
||||
return 2
|
||||
elif nmne_count > 0:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if CAPTURE_NMNE:
|
||||
obs_dict.update({"nmne": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
|
||||
return obs_dict
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"nic_status": spaces.Discrete(3),
|
||||
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config["nic_num"]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -67,8 +67,13 @@ class PrimaiteGameOptions(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[str]
|
||||
"""A whitelist of available ports in the simulation."""
|
||||
protocols: List[str]
|
||||
"""A whitelist of available protocols in the simulation."""
|
||||
thresholds: Optional[Dict] = {}
|
||||
"""A dict containing the thresholds used for determining what is acceptable during observations."""
|
||||
|
||||
|
||||
class PrimaiteGame:
|
||||
|
||||
@@ -5,8 +5,9 @@ from typing import Union
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -43,15 +44,15 @@ def test_example_config():
|
||||
|
||||
# green agent 1
|
||||
assert "client_2_green_user" in game.agents
|
||||
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
|
||||
assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent)
|
||||
|
||||
# green agent 2
|
||||
assert "client_1_green_user" in game.agents
|
||||
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
|
||||
assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent)
|
||||
|
||||
# red agent
|
||||
assert "client_1_data_manipulation_red_bot" in game.agents
|
||||
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
|
||||
assert "data_manipulation_attacker" in game.agents
|
||||
assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent)
|
||||
|
||||
# blue agent
|
||||
assert "defender" in game.agents
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_thresholds():
|
||||
"""Test that the game options can be parsed correctly."""
|
||||
game = load_config(data_manipulation_config_path())
|
||||
|
||||
assert game.options.thresholds is not None
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.observations.observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
@@ -33,3 +33,43 @@ def test_nic(simulation):
|
||||
nic.disable()
|
||||
observation_state = nic_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("nic_status") == 2 # disabled
|
||||
|
||||
|
||||
def test_nic_categories(simulation):
|
||||
"""Test the NIC observation nmne count categories."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 10 # default
|
||||
assert nic_obs.med_nmne_threshold == 5 # default
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
nic_obs = NicObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
)
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 9
|
||||
assert nic_obs.med_nmne_threshold == 6
|
||||
assert nic_obs.low_nmne_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
high_nmne_threshold=9,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.game.agent.observations.observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
Reference in New Issue
Block a user