#2676: Store NMNE config data in class variable.

This commit is contained in:
Nick Todd
2024-07-04 15:41:13 +01:00
parent dbc1d73c34
commit 47df2aa569
2 changed files with 20 additions and 35 deletions

View File

@@ -15,7 +15,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -23,7 +23,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import store_nmne_config, NmneData
from primaite.simulator.network.nmne import NmneData, store_nmne_config
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -239,6 +239,8 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {}))
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
@@ -500,10 +502,6 @@ class PrimaiteGame:
game.setup_reward_sharing()
game.update_agents(game.get_sim_state())
# Set the NMNE capture config
game.nmne_config = store_nmne_config(network_config.get("nmne_config", {}))
return game
def setup_reward_sharing(self):
@@ -543,6 +541,3 @@ class PrimaiteGame:
# sort the agents so the rewards that depend on other rewards are always evaluated later
self._reward_calculation_order = topological_sort(graph)
def get_nmne_config(self) -> NmneData:
return self.nmne_config

View File

@@ -6,12 +6,11 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.nmne import NmneData
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.system.applications.application import Application
@@ -108,8 +99,8 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
nmne_config: ClassVar[NmneData] = None
"A dataclass defining malicious network events to be captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -117,7 +108,6 @@ class NetworkInterface(SimComponent, ABC):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)
self.nmne = {}
self.traffic = {}
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
self.pcap.current_episode = episode
@@ -152,8 +142,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
if self.nmne_config and self.nmne_config.capture_nmne:
state.update({"nmne": {self.nmne_config.__dict__}})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -186,7 +176,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -207,27 +197,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
current_level = self.nmne_config
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -236,8 +226,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if self.nmne_config.capture_by_keyword:
for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -1067,7 +1057,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
network_interface.nmne if self.nmne_config.capture_nmne else "Disabled",
]
)
print(table)