From 47df2aa56940c26047c4e2b6672867e9016c8b1f Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 4 Jul 2024 15:41:13 +0100 Subject: [PATCH] #2676: Store NMNE config data in class variable. --- src/primaite/game/game.py | 13 ++---- .../simulator/network/hardware/base.py | 42 +++++++------------ 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cc559b4d..9636bd23 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -15,7 +15,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort -from primaite.simulator.network.hardware.base import NodeOperatingState +from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server @@ -23,7 +23,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import store_nmne_config, NmneData +from primaite.simulator.network.nmne import NmneData, store_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -239,6 +239,8 @@ class PrimaiteGame: nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) + # Set the NMNE capture config + NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) for node_cfg in nodes_cfg: n_type = node_cfg["type"] @@ -500,10 +502,6 @@ class PrimaiteGame: game.setup_reward_sharing() game.update_agents(game.get_sim_state()) - - # Set the NMNE capture config - game.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) - return game def setup_reward_sharing(self): @@ -543,6 +541,3 @@ class PrimaiteGame: # sort the agents so the rewards that depend on other rewards are always evaluated later self._reward_calculation_order = topological_sort(graph) - - def get_nmne_config(self) -> NmneData: - return self.nmne_config diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 01745215..6d753731 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,12 +6,11 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field -import primaite.simulator.network.nmne from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse @@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.nmne import ( - CAPTURE_BY_DIRECTION, - CAPTURE_BY_IP_ADDRESS, - CAPTURE_BY_KEYWORD, - CAPTURE_BY_PORT, - CAPTURE_BY_PROTOCOL, - CAPTURE_NMNE, - NMNE_CAPTURE_KEYWORDS, -) +from primaite.simulator.network.nmne import NmneData from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.applications.application import Application @@ -108,8 +99,8 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." - nmne: Dict = Field(default_factory=lambda: {}) - "A dict containing details of the number of malicious network events captured." + nmne_config: ClassVar[NmneData] = None + "A dataclass defining malicious network events to be captured." traffic: Dict = Field(default_factory=lambda: {}) "A dict containing details of the inbound and outbound traffic by port and protocol." @@ -117,7 +108,6 @@ class NetworkInterface(SimComponent, ABC): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" super().setup_for_episode(episode=episode) - self.nmne = {} self.traffic = {} if episode and self.pcap and SIM_OUTPUT.save_pcap_logs: self.pcap.current_episode = episode @@ -152,8 +142,8 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) - if CAPTURE_NMNE: - state.update({"nmne": {k: v for k, v in self.nmne.items()}}) + if self.nmne_config and self.nmne_config.capture_nmne: + state.update({"nmne": {self.nmne_config.__dict__}}) state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)}) return state @@ -186,7 +176,7 @@ class NetworkInterface(SimComponent, ABC): :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. """ # Exit function if NMNE capturing is disabled - if not CAPTURE_NMNE: + if not (self.nmne_config and self.nmne_config.capture_nmne): return # Initialise basic frame data variables @@ -207,27 +197,27 @@ class NetworkInterface(SimComponent, ABC): frame_str = str(frame.payload) # Proceed only if any NMNE keyword is present in the frame payload - if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS): + if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords): # Start with the root of the NMNE capture structure - current_level = self.nmne + current_level = self.nmne_config # Update NMNE structure based on enabled settings - if CAPTURE_BY_DIRECTION: + if self.nmne_config.capture_by_direction: # Set or get the dictionary for the current direction current_level = current_level.setdefault("direction", {}) current_level = current_level.setdefault(direction, {}) - if CAPTURE_BY_IP_ADDRESS: + if self.nmne_config.capture_by_ip_address: # Set or get the dictionary for the current IP address current_level = current_level.setdefault("ip_address", {}) current_level = current_level.setdefault(ip_address, {}) - if CAPTURE_BY_PROTOCOL: + if self.nmne_config.capture_by_protocol: # Set or get the dictionary for the current protocol current_level = current_level.setdefault("protocol", {}) current_level = current_level.setdefault(protocol, {}) - if CAPTURE_BY_PORT: + if self.nmne_config.capture_by_port: # Set or get the dictionary for the current port current_level = current_level.setdefault("port", {}) current_level = current_level.setdefault(port, {}) @@ -236,8 +226,8 @@ class NetworkInterface(SimComponent, ABC): keyword_level = current_level.setdefault("keywords", {}) # Increment the count for detected keywords in the payload - if CAPTURE_BY_KEYWORD: - for keyword in NMNE_CAPTURE_KEYWORDS: + if self.nmne_config.capture_by_keyword: + for keyword in self.nmne_config.nmne_capture_keywords: if keyword in frame_str: # Update the count for each keyword found keyword_level[keyword] = keyword_level.get(keyword, 0) + 1 @@ -1067,7 +1057,7 @@ class Node(SimComponent): ip_address, network_interface.speed, "Enabled" if network_interface.enabled else "Disabled", - network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled", + network_interface.nmne if self.nmne_config.capture_nmne else "Disabled", ] ) print(table)