#2676: Store NMNE config data in class variable.

This commit is contained in:
Nick Todd
2024-07-04 15:41:13 +01:00
parent dbc1d73c34
commit 47df2aa569
2 changed files with 20 additions and 35 deletions

View File

@@ -6,12 +6,11 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.nmne import NmneData
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.system.applications.application import Application
@@ -108,8 +99,8 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
nmne_config: ClassVar[NmneData] = None
"A dataclass defining malicious network events to be captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -117,7 +108,6 @@ class NetworkInterface(SimComponent, ABC):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)
self.nmne = {}
self.traffic = {}
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
self.pcap.current_episode = episode
@@ -152,8 +142,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
if self.nmne_config and self.nmne_config.capture_nmne:
state.update({"nmne": {self.nmne_config.__dict__}})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -186,7 +176,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -207,27 +197,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
current_level = self.nmne_config
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -236,8 +226,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if self.nmne_config.capture_by_keyword:
for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -1067,7 +1057,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
network_interface.nmne if self.nmne_config.capture_nmne else "Disabled",
]
)
print(table)