#2350: configurable NMNE category thresholds

This commit is contained in:
Czar Echavez
2024-03-08 19:32:07 +00:00
parent 61aa242128
commit cc721056d8
9 changed files with 261 additions and 111 deletions

View File

@@ -30,6 +30,11 @@ game:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: client_2_green_user

View File

@@ -0,0 +1,175 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
low_nmne_threshold: int = 0
"""The minimum number of malicious network events to be considered low."""
med_nmne_threshold: int = 5
"""The minimum number of malicious network events to be considered medium."""
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"nmne": {"inbound": 0, "outbound": 0}})
return data
def __init__(
self,
where: Optional[Tuple[str]] = None,
low_nmne_threshold: Optional[int] = 0,
med_nmne_threshold: Optional[int] = 5,
high_nmne_threshold: Optional[int] = 10,
) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
def _validate_nmne_categories(
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
):
"""
Validates the nmne threshold config.
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
"""
if high_nmne_threshold <= med_nmne_threshold:
raise Exception(
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
f"than medium nmne count ({med_nmne_threshold})"
)
if med_nmne_threshold <= low_nmne_threshold:
raise Exception(
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
f"than low nmne count ({low_nmne_threshold})"
)
self.high_nmne_threshold = high_nmne_threshold
self.med_nmne_threshold = med_nmne_threshold
self.low_nmne_threshold = low_nmne_threshold
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (default 1-5 events).
- 2: Moderate number of MNEs (default 6-10 events).
- 3: High number of MNEs (default more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_nmne_threshold:
return 3
elif nmne_count > self.med_nmne_threshold:
return 2
elif nmne_count > self.low_nmne_threshold:
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"nmne": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"nic_status": spaces.Discrete(3),
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
low_nmne_threshold = None
med_nmne_threshold = None
high_nmne_threshold = None
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
threshold = game.options.thresholds["nmne"]
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
return cls(
where=parent_where + ["NICs", config["nic_num"]],
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)

View File

@@ -4,7 +4,8 @@ from gymnasium import spaces
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.observations import AbstractObservation, NicObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.software_observation import ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE

View File

@@ -7,7 +7,6 @@ from gymnasium import spaces
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
_LOGGER = getLogger(__name__)
@@ -116,107 +115,6 @@ class LinkObservation(AbstractObservation):
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"nmne": {"inbound": 0, "outbound": 0}})
return data
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (1-5 events).
- 2: Moderate number of MNEs (6-10 events).
- 3: High number of MNEs (more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > 10:
return 3
elif nmne_count > 5:
return 2
elif nmne_count > 0:
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"nmne": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"nic_status": spaces.Discrete(3),
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
return cls(where=parent_where + ["NICs", config["nic_num"]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""

View File

@@ -1,6 +1,6 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel, ConfigDict
@@ -67,8 +67,13 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]
"""A whitelist of available ports in the simulation."""
protocols: List[str]
"""A whitelist of available protocols in the simulation."""
thresholds: Optional[Dict] = {}
"""A dict containing the thresholds used for determining what is acceptable during observations."""
class PrimaiteGame:

View File

@@ -5,8 +5,9 @@ from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent, RandomAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -43,15 +44,15 @@ def test_example_config():
# green agent 1
assert "client_2_green_user" in game.agents
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent)
# green agent 2
assert "client_1_green_user" in game.agents
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent)
# red agent
assert "client_1_data_manipulation_red_bot" in game.agents
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
assert "data_manipulation_attacker" in game.agents
assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent)
# blue agent
assert "defender" in game.agents

View File

@@ -0,0 +1,25 @@
from pathlib import Path
from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.game import PrimaiteGame
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
def test_thresholds():
"""Test that the game options can be parsed correctly."""
game = load_config(data_manipulation_config_path())
assert game.options.thresholds is not None

View File

@@ -1,6 +1,6 @@
import pytest
from primaite.game.agent.observations.observations import NicObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.sim_container import Simulation
@@ -33,3 +33,43 @@ def test_nic(simulation):
nic.disable()
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 2 # disabled
def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
nic_obs = NicObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
high_nmne_threshold=9,
)
assert nic_obs.high_nmne_threshold == 9
assert nic_obs.med_nmne_threshold == 6
assert nic_obs.low_nmne_threshold == 3
with pytest.raises(Exception):
# should throw an error
NicObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
high_nmne_threshold=9,
)
with pytest.raises(Exception):
# should throw an error
NicObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,
high_nmne_threshold=9,
)

View File

@@ -1,4 +1,4 @@
from primaite.game.agent.observations.observations import NicObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.sim_container import Simulation