Merged PR 279: #2238 - Implement NMNE detection and capture

## Summary

This pull request introduces new features for capturing Malicious Network Events (NMNE) within the `NetworkInterface` class and extends the functionality to the `NicObservation` class .Additionally, it updates the simulation configuration to allow customizable NMNE capturing settings.

### Changes
- `NetworkInterface` Enhancements: Added NMNE capturing capabilities to detect and log specified malicious activities.
- `NicObservation` Integration: Updated to support detailed monitoring and analysis based on NMNE capturing results.
- Simulation Configuration: Introduced nmne_config options allowing users to enable NMNE capturing and define specific keywords, enhancing the adaptability of network security measures.

### New Configuration Options
Added to simulation.yml:
``` yaml
simulation:
  network:
    nmne_config:
      capture_nmne: true
      nmne_capture_keywords:
        - DELETE
```

Tests

Documentation
Updated README and related documentation to guide users on how to utilize the new NMNE capturing features and configure them in their simulations.

## Test process
- **NMNE Capture Testing**: Implemented tests using the UC2 network setup, where DELETE SQL queries are initiated by the database client residing on the web server and targeted towards the database service on the database server. Post-query, the network interface cards (NICs) on both servers are examined to verify accurate counting and logging of NMNE (Malicious Network Events) as expected per configuration.

- **NicObservation Testing**: Introduced additional tests to ensure proper integration of the `NicObservation `class, focusing on its ability to accurately observe and report NMNE occurrences.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [X] updated the **documentation** if this PR changes or adds functionality
- [X] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

#2238

Related work items: #2238
This commit is contained in:
Christopher McCarthy
2024-02-28 14:57:46 +00:00
11 changed files with 351 additions and 18 deletions

View File

@@ -82,7 +82,8 @@ SessionManager.
- `AirSpace` class to simulate wireless communications, managing wireless interfaces and facilitating the transmission of frames within specified frequencies.
- `AirSpaceFrequency` enum for defining standard wireless frequencies, including 2.4 GHz and 5 GHz bands, to support realistic wireless network simulations.
- `WirelessRouter` class, extending the `Router` class, to incorporate wireless networking capabilities alongside traditional wired connections. This class allows the configuration of wireless access points with specific IP settings and operating frequencies.
- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events.
- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE".
### Changed
- Integrated the RouteTable into the Routers frame processing.
@@ -94,7 +95,8 @@ SessionManager.
- Refactored all tests to utilise new `Node` subclasses (`Computer`, `Server`, `Router`, `Switch`) instead of creating generic `Node` instances and manually adding network interfaces. This change aligns test setups more closely with the intended use cases and hierarchies within the network simulation framework.
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
- Integration of NMNE capturing functionality within the `NicObservation` class.
### Removed
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`

View File

@@ -65,9 +65,14 @@ Network Interface Classes
**NetworkInterface (Base Layer)**
Abstract base class defining core interface properties like MAC address, speed, MTU.
Requires subclasses implement key methods like send/receive frames, enable/disable interface.
Establishes universal network interface capabilities.
- Abstract base class defining core interface properties like MAC address, speed, MTU.
- Requires subclasses implement key methods like send/receive frames, enable/disable interface.
- Establishes universal network interface capabilities.
- Malicious Network Events Monitoring:
* Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies.
* Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms.
**WiredNetworkInterface (Connection Type Layer)**

View File

@@ -583,6 +583,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- ref: router_1

View File

@@ -963,6 +963,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- ref: router_1

View File

@@ -8,6 +8,7 @@ from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
_LOGGER = getLogger(__name__)
@@ -346,7 +347,14 @@ class FolderObservation(AbstractObservation):
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
default_observation: spaces.Space = {"nic_status": 0}
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"nmne": {"inbound": 0, "outbound": 0}})
return data
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise NIC observation.
@@ -360,6 +368,29 @@ class NicObservation(AbstractObservation):
super().__init__()
self.where: Optional[Tuple[str]] = where
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (1-5 events).
- 2: Moderate number of MNEs (6-10 events).
- 3: High number of MNEs (more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > 10:
return 3
elif nmne_count > 5:
return 2
elif nmne_count > 0:
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
@@ -371,15 +402,31 @@ class NicObservation(AbstractObservation):
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
return {"nic_status": 1 if nic_state["enabled"] else 2}
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"nmne": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"nic_status": spaces.Discrete(3)})
return spaces.Dict(
{
"nic_status": spaces.Discrete(3),
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":

View File

@@ -17,6 +17,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -426,4 +427,7 @@ class PrimaiteGame:
game.simulation.set_original_state()
# Set the NMNE capture config
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))
return game

View File

@@ -130,6 +130,9 @@
" - NETWORK_INTERFACES\n",
" - <nic_id 1-2>\n",
" - nic_status\n",
" - nmne\n",
" - inbound\n",
" - outbound\n",
" - operating_status\n",
"- LINKS\n",
" - <link_id 1-10>\n",
@@ -220,6 +223,14 @@
"|1|ENABLED|\n",
"|2|DISABLED|\n",
"\n",
"NMNE (number of malicious network events) means, for inbound or outbound traffic, means:\n",
"|value|NMNEs|\n",
"|--|--|\n",
"|0|None|\n",
"|1|1 - 5|\n",
"|2|6 - 10|\n",
"|3|More than 10|\n",
"\n",
"Link load has the following meaning:\n",
"|load|percent utilisation|\n",
"|--|--|\n",

View File

@@ -17,6 +17,15 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
@@ -88,6 +97,9 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -111,11 +123,14 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": self.nmne})
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
self.nmne = {}
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
@@ -131,6 +146,82 @@ class NetworkInterface(SimComponent, ABC):
"""Disable the interface."""
pass
def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None:
"""
Processes and captures network frame data based on predefined global NMNE settings.
This method updates the NMNE structure with counts of malicious network events based on the frame content and
direction. The structure is dynamically adjusted according to the enabled capture settings.
.. note::
While there is a lot of logic in this code that defines a multi-level hierarchical NMNE structure,
most of it is unused for now as a result of all `CAPTURE_BY_<>` variables in
``primaite.simulator.network.nmne`` being hardcoded and set as final. Once they're 'released' and made
configurable, this function will be updated to properly explain the dynamic data structure.
:param frame: The network frame to process, containing IP, TCP/UDP, and payload information.
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
return
# Initialise basic frame data variables
direction = "inbound" if inbound else "outbound" # Direction of the traffic
ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP
protocol = frame.ip.protocol.name # Network protocol used in the frame
# Initialise port variable; will be determined based on protocol type
port = None
# Determine the source or destination port based on the protocol (TCP/UDP)
if frame.tcp:
port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value
elif frame.udp:
port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value
# Convert frame payload to string for keyword checking
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
# Ensure 'KEYWORD' level is present in the structure
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
else:
# Increment a generic counter if keyword capturing is not enabled
keyword_level["*"] = keyword_level.get("*", 0) + 1
@abstractmethod
def send_frame(self, frame: Frame) -> bool:
"""
@@ -139,7 +230,7 @@ class NetworkInterface(SimComponent, ABC):
:param frame: The network frame to be sent.
:return: A boolean indicating whether the frame was successfully sent.
"""
pass
self._capture_nmne(frame, inbound=False)
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
@@ -149,7 +240,7 @@ class NetworkInterface(SimComponent, ABC):
:param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received.
"""
pass
self._capture_nmne(frame, inbound=True)
def __str__(self) -> str:
"""
@@ -263,6 +354,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame to be sent.
:return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link.
"""
super().send_frame(frame)
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
@@ -279,7 +371,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received.
"""
pass
return super().receive_frame(frame)
class Layer3Interface(BaseModel, ABC):
@@ -409,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
except AttributeError:
pass
# @abstractmethod
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
"""
Receives a network frame on the network interface.
@@ -417,7 +509,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
:param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received.
"""
pass
return super().receive_frame(frame)
class Link(SimComponent):

View File

@@ -205,11 +205,7 @@ class NIC(IPWiredNetworkInterface):
state = super().describe_state()
# Update the state with NIC-specific information
state.update(
{
"wake_on_lan": self.wake_on_lan,
}
)
state.update({"wake_on_lan": self.wake_on_lan})
return state
@@ -248,6 +244,7 @@ class NIC(IPWiredNetworkInterface):
accept_frame = True
if accept_frame:
super().receive_frame(frame)
self._connected_node.receive_frame(frame=frame, from_network_interface=self)
return True
return False

View File

@@ -0,0 +1,47 @@
from typing import Dict, Final, List
CAPTURE_NMNE: bool = True
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination IP address."""
CAPTURE_BY_PROTOCOL: Final[bool] = False
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
CAPTURE_BY_PORT: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination port."""
CAPTURE_BY_KEYWORD: Final[bool] = False
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
def set_nmne_config(nmne_config: Dict):
"""
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
keywords to use for identifying NMNEs.
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
and maintains type integrity by checking the types of the provided values.
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
to specify keywords for NMNE identification.
"""
global NMNE_CAPTURE_KEYWORDS
global CAPTURE_NMNE
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
if not isinstance(CAPTURE_NMNE, bool):
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
# Update the NMNE capture keywords, appending new keywords if provided
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list

View File

@@ -0,0 +1,120 @@
from primaite.game.agent.observations import NicObservation
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
def test_capture_nmne(uc2_network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
This test involves a web server querying a database server and checks if the MNEs are captured
based on predefined keywords in the network configuration. Specifically, it checks the capture
of the "DELETE" SQL command as a malicious network event.
"""
web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa
db_client.connect()
db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa
web_server_nic = web_server.network_interface[1]
db_server_nic = db_server.network_interface[1]
# Set the NMNE configuration to capture DELETE queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
assert web_server_nic.describe_state()["nmne"] == {}
assert db_server_nic.describe_state()["nmne"] == {}
# Perform a "SELECT" query
db_client.query("SELECT")
# Check that it does not trigger an MNE capture.
assert web_server_nic.describe_state()["nmne"] == {}
assert db_server_nic.describe_state()["nmne"] == {}
# Perform a "DELETE" query
db_client.query("DELETE")
# Check that the web server's outbound interface and the database server's inbound interface register the MNE
assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}}
assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
# Perform another "SELECT" query
db_client.query("SELECT")
# Check that no additional MNEs are captured
assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}}
assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
# Perform another "DELETE" query
db_client.query("DELETE")
# Check that the web server and database server interfaces register an additional MNE
assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}}
assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}}
def test_capture_nmne_observations(uc2_network):
"""
Tests the NicObservation class's functionality within a simulated network environment.
This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the
number of MNEs detected based on network activities over multiple iterations.
The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update
and accuracy of the observation space related to network interface conditions. It confirms that the
observed NIC states match expected MNE activity levels.
"""
# Initialise a new Simulation instance and assign the test network to it.
sim = Simulation()
sim.network = uc2_network
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
# Set the NMNE configuration to capture DELETE queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1])
web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1])
# Iterate through a set of test cases to simulate multiple DELETE queries
for i in range(1, 20):
# Perform a "DELETE" query each iteration
db_client.query("DELETE")
# Observe the current state of NMNEs from the NICs of both the database and web servers
db_nic_obs = db_server_nic_obs.observe(sim.describe_state())["nmne"]
web_nic_obs = web_server_nic_obs.observe(sim.describe_state())["nmne"]
# Define expected NMNE values based on the iteration count
if i > 10:
expected_nmne = 3 # High level of detected MNEs after 10 iterations
elif i > 5:
expected_nmne = 2 # Moderate level after more than 5 iterations
elif i > 0:
expected_nmne = 1 # Low level detected after just starting
else:
expected_nmne = 0 # No MNEs detected
# Assert that the observed NMNEs match the expected values for both NICs
assert web_nic_obs["outbound"] == expected_nmne
assert db_nic_obs["inbound"] == expected_nmne