#2238 - Implement NMNE detection and logging in NetworkInterface.

- Enhance NicObservation for detailed NMNE event monitoring.
- Add nmne_config options to simulation settings for customizable NMNE capturing.
- Update documentation and tests for new NMNE features and simulation config.
This commit is contained in:
Chris McCarthy
2024-02-22 22:43:14 +00:00
parent 8f85555709
commit 771a68dccb
10 changed files with 333 additions and 17 deletions

View File

@@ -583,6 +583,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- ref: router_1

View File

@@ -963,6 +963,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- ref: router_1

View File

@@ -8,6 +8,7 @@ from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
_LOGGER = getLogger(__name__)
@@ -346,7 +347,14 @@ class FolderObservation(AbstractObservation):
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
default_observation: spaces.Space = {"nic_status": 0}
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"nmne": {"inbound": 0, "outbound": 0}})
return data
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise NIC observation.
@@ -360,6 +368,29 @@ class NicObservation(AbstractObservation):
super().__init__()
self.where: Optional[Tuple[str]] = where
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (1-5 events).
- 2: Moderate number of MNEs (6-10 events).
- 3: High number of MNEs (more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > 10:
return 3
elif nmne_count > 5:
return 2
elif nmne_count > 0:
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
@@ -371,15 +402,30 @@ class NicObservation(AbstractObservation):
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
return {"nic_status": 1 if nic_state["enabled"] else 2}
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2, "nmne": {}}
if CAPTURE_NMNE:
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"nic_status": spaces.Discrete(3)})
return spaces.Dict(
{
"nic_status": spaces.Discrete(3),
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":

View File

@@ -17,6 +17,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -426,4 +427,7 @@ class PrimaiteGame:
game.simulation.set_original_state()
# Set the NMNE capture config
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))
return game

View File

@@ -17,6 +17,15 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
@@ -88,6 +97,8 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne: Dict = Field(default_factory=lambda: {})
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -111,27 +122,99 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
state.update({"nmne": self.nmne})
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
self.nmne = {}
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
@abstractmethod
# @abstractmethod
def enable(self):
"""Enable the interface."""
pass
@abstractmethod
# @abstractmethod
def disable(self):
"""Disable the interface."""
pass
@abstractmethod
def _capture_nmne(self, frame: Frame, inbound: bool = True):
"""
Processes and captures network frame data based on predefined global NMNE settings.
This method updates the NMNE structure with counts of malicious network events based on the frame content and
direction. The structure is dynamically adjusted according to the enabled capture settings.
:param frame: The network frame to process, containing IP, TCP/UDP, and payload information.
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
return
# Initialise basic frame data variables
direction = "inbound" if inbound else "outbound" # Direction of the traffic
ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP
protocol = frame.ip.protocol.name # Network protocol used in the frame
# Initialise port variable; will be determined based on protocol type
port = None
# Determine the source or destination port based on the protocol (TCP/UDP)
if frame.tcp:
port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value
elif frame.udp:
port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value
# Convert frame payload to string for keyword checking
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
# Ensure 'KEYWORD' level is present in the structure
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
else:
# Increment a generic counter if keyword capturing is not enabled
keyword_level["*"] = keyword_level.get("*", 0) + 1
# @abstractmethod
def send_frame(self, frame: Frame) -> bool:
"""
Attempts to send a network frame through the interface.
@@ -139,9 +222,9 @@ class NetworkInterface(SimComponent, ABC):
:param frame: The network frame to be sent.
:return: A boolean indicating whether the frame was successfully sent.
"""
pass
self._capture_nmne(frame, inbound=False)
@abstractmethod
# @abstractmethod
def receive_frame(self, frame: Frame) -> bool:
"""
Receives a network frame on the interface.
@@ -149,7 +232,7 @@ class NetworkInterface(SimComponent, ABC):
:param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received.
"""
pass
self._capture_nmne(frame, inbound=True)
def __str__(self) -> str:
"""
@@ -263,6 +346,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame to be sent.
:return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link.
"""
super().send_frame(frame)
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
@@ -279,7 +363,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received.
"""
pass
return super().receive_frame(frame)
class Layer3Interface(BaseModel, ABC):
@@ -409,7 +493,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
except AttributeError:
pass
# @abstractmethod
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
"""
Receives a network frame on the network interface.
@@ -417,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
:param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received.
"""
pass
return super().receive_frame(frame)
class Link(SimComponent):

View File

@@ -248,6 +248,7 @@ class NIC(IPWiredNetworkInterface):
accept_frame = True
if accept_frame:
super().receive_frame(frame)
self._connected_node.receive_frame(frame=frame, from_network_interface=self)
return True
return False

View File

@@ -0,0 +1,46 @@
from typing import Dict, Final, List
CAPTURE_NMNE: bool = True
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination IP address."""
CAPTURE_BY_PROTOCOL: Final[bool] = False
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
CAPTURE_BY_PORT: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination port."""
CAPTURE_BY_KEYWORD: Final[bool] = False
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
def set_nmne_config(nmne_config: Dict):
"""
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
keywords to use for identifying NMNEs.
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
and maintains type integrity by checking the types of the provided values.
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
to specify keywords for NMNE identification.
"""
global NMNE_CAPTURE_KEYWORDS
global CAPTURE_NMNE
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
if not isinstance(CAPTURE_NMNE, bool):
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
# Update the NMNE capture keywords, appending new keywords if provided
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list