#2238 - Implement NMNE detection and logging in NetworkInterface.
- Enhance NicObservation for detailed NMNE event monitoring. - Add nmne_config options to simulation settings for customizable NMNE capturing. - Update documentation and tests for new NMNE features and simulation config.
This commit is contained in:
@@ -583,6 +583,10 @@ agents:
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
|
||||
@@ -963,6 +963,10 @@ agents:
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
|
||||
@@ -8,6 +8,7 @@ from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -346,7 +347,14 @@ class FolderObservation(AbstractObservation):
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"nic_status": 0}
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"nmne": {"inbound": 0, "outbound": 0}})
|
||||
return data
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise NIC observation.
|
||||
@@ -360,6 +368,29 @@ class NicObservation(AbstractObservation):
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||
|
||||
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
|
||||
|
||||
Bins are defined as follows:
|
||||
- 0: No MNEs detected (0 events).
|
||||
- 1: Low number of MNEs (1-5 events).
|
||||
- 2: Moderate number of MNEs (6-10 events).
|
||||
- 3: High number of MNEs (more than 10 events).
|
||||
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > 10:
|
||||
return 3
|
||||
elif nmne_count > 5:
|
||||
return 2
|
||||
elif nmne_count > 0:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
@@ -371,15 +402,30 @@ class NicObservation(AbstractObservation):
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
return {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2, "nmne": {}}
|
||||
if CAPTURE_NMNE:
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
|
||||
return obs_dict
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
return spaces.Dict(
|
||||
{
|
||||
"nic_status": spaces.Discrete(3),
|
||||
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
|
||||
@@ -17,6 +17,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -426,4 +427,7 @@ class PrimaiteGame:
|
||||
|
||||
game.simulation.set_original_state()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))
|
||||
|
||||
return game
|
||||
|
||||
@@ -17,6 +17,15 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.nmne import (
|
||||
CAPTURE_BY_DIRECTION,
|
||||
CAPTURE_BY_IP_ADDRESS,
|
||||
CAPTURE_BY_KEYWORD,
|
||||
CAPTURE_BY_PORT,
|
||||
CAPTURE_BY_PROTOCOL,
|
||||
CAPTURE_NMNE,
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
@@ -88,6 +97,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
pcap: Optional[PacketCapture] = None
|
||||
"A PacketCapture instance for capturing and analysing packets passing through this interface."
|
||||
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -111,27 +122,99 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
state.update({"nmne": self.nmne})
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().reset_component_for_episode(episode)
|
||||
self.nmne = {}
|
||||
if episode and self.pcap:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
self.enable()
|
||||
|
||||
@abstractmethod
|
||||
# @abstractmethod
|
||||
def enable(self):
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
# @abstractmethod
|
||||
def disable(self):
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _capture_nmne(self, frame: Frame, inbound: bool = True):
|
||||
"""
|
||||
Processes and captures network frame data based on predefined global NMNE settings.
|
||||
|
||||
This method updates the NMNE structure with counts of malicious network events based on the frame content and
|
||||
direction. The structure is dynamically adjusted according to the enabled capture settings.
|
||||
|
||||
:param frame: The network frame to process, containing IP, TCP/UDP, and payload information.
|
||||
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
|
||||
"""
|
||||
# Exit function if NMNE capturing is disabled
|
||||
if not CAPTURE_NMNE:
|
||||
return
|
||||
|
||||
# Initialise basic frame data variables
|
||||
direction = "inbound" if inbound else "outbound" # Direction of the traffic
|
||||
ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP
|
||||
protocol = frame.ip.protocol.name # Network protocol used in the frame
|
||||
|
||||
# Initialise port variable; will be determined based on protocol type
|
||||
port = None
|
||||
|
||||
# Determine the source or destination port based on the protocol (TCP/UDP)
|
||||
if frame.tcp:
|
||||
port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value
|
||||
elif frame.udp:
|
||||
port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value
|
||||
|
||||
# Convert frame payload to string for keyword checking
|
||||
frame_str = str(frame.payload)
|
||||
|
||||
# Proceed only if any NMNE keyword is present in the frame payload
|
||||
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
|
||||
# Start with the root of the NMNE capture structure
|
||||
current_level = self.nmne
|
||||
|
||||
# Update NMNE structure based on enabled settings
|
||||
if CAPTURE_BY_DIRECTION:
|
||||
# Set or get the dictionary for the current direction
|
||||
current_level = current_level.setdefault("direction", {})
|
||||
current_level = current_level.setdefault(direction, {})
|
||||
|
||||
if CAPTURE_BY_IP_ADDRESS:
|
||||
# Set or get the dictionary for the current IP address
|
||||
current_level = current_level.setdefault("ip_address", {})
|
||||
current_level = current_level.setdefault(ip_address, {})
|
||||
|
||||
if CAPTURE_BY_PROTOCOL:
|
||||
# Set or get the dictionary for the current protocol
|
||||
current_level = current_level.setdefault("protocol", {})
|
||||
current_level = current_level.setdefault(protocol, {})
|
||||
|
||||
if CAPTURE_BY_PORT:
|
||||
# Set or get the dictionary for the current port
|
||||
current_level = current_level.setdefault("port", {})
|
||||
current_level = current_level.setdefault(port, {})
|
||||
|
||||
# Ensure 'KEYWORD' level is present in the structure
|
||||
keyword_level = current_level.setdefault("keywords", {})
|
||||
|
||||
# Increment the count for detected keywords in the payload
|
||||
if CAPTURE_BY_KEYWORD:
|
||||
for keyword in NMNE_CAPTURE_KEYWORDS:
|
||||
if keyword in frame_str:
|
||||
# Update the count for each keyword found
|
||||
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
|
||||
else:
|
||||
# Increment a generic counter if keyword capturing is not enabled
|
||||
keyword_level["*"] = keyword_level.get("*", 0) + 1
|
||||
|
||||
# @abstractmethod
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Attempts to send a network frame through the interface.
|
||||
@@ -139,9 +222,9 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:param frame: The network frame to be sent.
|
||||
:return: A boolean indicating whether the frame was successfully sent.
|
||||
"""
|
||||
pass
|
||||
self._capture_nmne(frame, inbound=False)
|
||||
|
||||
@abstractmethod
|
||||
# @abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receives a network frame on the interface.
|
||||
@@ -149,7 +232,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:param frame: The network frame being received.
|
||||
:return: A boolean indicating whether the frame was successfully received.
|
||||
"""
|
||||
pass
|
||||
self._capture_nmne(frame, inbound=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
@@ -263,6 +346,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
:param frame: The network frame to be sent.
|
||||
:return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link.
|
||||
"""
|
||||
super().send_frame(frame)
|
||||
if self.enabled:
|
||||
frame.set_sent_timestamp()
|
||||
self.pcap.capture_outbound(frame)
|
||||
@@ -279,7 +363,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
:param frame: The network frame being received.
|
||||
:return: A boolean indicating whether the frame was successfully received.
|
||||
"""
|
||||
pass
|
||||
return super().receive_frame(frame)
|
||||
|
||||
|
||||
class Layer3Interface(BaseModel, ABC):
|
||||
@@ -409,7 +493,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# @abstractmethod
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receives a network frame on the network interface.
|
||||
@@ -417,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
:param frame: The network frame being received.
|
||||
:return: A boolean indicating whether the frame was successfully received.
|
||||
"""
|
||||
pass
|
||||
return super().receive_frame(frame)
|
||||
|
||||
|
||||
class Link(SimComponent):
|
||||
|
||||
@@ -248,6 +248,7 @@ class NIC(IPWiredNetworkInterface):
|
||||
accept_frame = True
|
||||
|
||||
if accept_frame:
|
||||
super().receive_frame(frame)
|
||||
self._connected_node.receive_frame(frame=frame, from_network_interface=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
46
src/primaite/simulator/network/nmne.py
Normal file
46
src/primaite/simulator/network/nmne.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Dict, Final, List
|
||||
|
||||
CAPTURE_NMNE: bool = True
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
|
||||
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination IP address."""
|
||||
CAPTURE_BY_PROTOCOL: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
CAPTURE_BY_PORT: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination port."""
|
||||
CAPTURE_BY_KEYWORD: Final[bool] = False
|
||||
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
|
||||
|
||||
|
||||
def set_nmne_config(nmne_config: Dict):
|
||||
"""
|
||||
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
|
||||
|
||||
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
|
||||
keywords to use for identifying NMNEs.
|
||||
|
||||
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
|
||||
and maintains type integrity by checking the types of the provided values.
|
||||
|
||||
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
|
||||
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
|
||||
to specify keywords for NMNE identification.
|
||||
"""
|
||||
global NMNE_CAPTURE_KEYWORDS
|
||||
global CAPTURE_NMNE
|
||||
|
||||
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
|
||||
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
|
||||
if not isinstance(CAPTURE_NMNE, bool):
|
||||
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
|
||||
|
||||
# Update the NMNE capture keywords, appending new keywords if provided
|
||||
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
|
||||
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
|
||||
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
|
||||
Reference in New Issue
Block a user