#2676: Store NMNE config data in class variable.
This commit is contained in:
@@ -15,7 +15,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
|
||||
@@ -23,7 +23,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from primaite.simulator.network.nmne import store_nmne_config, NmneData
|
||||
from primaite.simulator.network.nmne import NmneData, store_nmne_config
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -239,6 +239,8 @@ class PrimaiteGame:
|
||||
|
||||
nodes_cfg = network_config.get("nodes", [])
|
||||
links_cfg = network_config.get("links", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {}))
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
@@ -500,10 +502,6 @@ class PrimaiteGame:
|
||||
game.setup_reward_sharing()
|
||||
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
# Set the NMNE capture config
|
||||
game.nmne_config = store_nmne_config(network_config.get("nmne_config", {}))
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
|
||||
@@ -543,6 +541,3 @@ class PrimaiteGame:
|
||||
|
||||
# sort the agents so the rewards that depend on other rewards are always evaluated later
|
||||
self._reward_calculation_order = topological_sort(graph)
|
||||
|
||||
def get_nmne_config(self) -> NmneData:
|
||||
return self.nmne_config
|
||||
|
||||
@@ -6,12 +6,11 @@ import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import primaite.simulator.network.nmne
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.nmne import (
|
||||
CAPTURE_BY_DIRECTION,
|
||||
CAPTURE_BY_IP_ADDRESS,
|
||||
CAPTURE_BY_KEYWORD,
|
||||
CAPTURE_BY_PORT,
|
||||
CAPTURE_BY_PROTOCOL,
|
||||
CAPTURE_NMNE,
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.nmne import NmneData
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -108,8 +99,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
pcap: Optional[PacketCapture] = None
|
||||
"A PacketCapture instance for capturing and analysing packets passing through this interface."
|
||||
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
nmne_config: ClassVar[NmneData] = None
|
||||
"A dataclass defining malicious network events to be captured."
|
||||
|
||||
traffic: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the inbound and outbound traffic by port and protocol."
|
||||
@@ -117,7 +108,6 @@ class NetworkInterface(SimComponent, ABC):
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode=episode)
|
||||
self.nmne = {}
|
||||
self.traffic = {}
|
||||
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
|
||||
self.pcap.current_episode = episode
|
||||
@@ -152,8 +142,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
if CAPTURE_NMNE:
|
||||
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
|
||||
if self.nmne_config and self.nmne_config.capture_nmne:
|
||||
state.update({"nmne": {self.nmne_config.__dict__}})
|
||||
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
|
||||
return state
|
||||
|
||||
@@ -186,7 +176,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
|
||||
"""
|
||||
# Exit function if NMNE capturing is disabled
|
||||
if not CAPTURE_NMNE:
|
||||
if not (self.nmne_config and self.nmne_config.capture_nmne):
|
||||
return
|
||||
|
||||
# Initialise basic frame data variables
|
||||
@@ -207,27 +197,27 @@ class NetworkInterface(SimComponent, ABC):
|
||||
frame_str = str(frame.payload)
|
||||
|
||||
# Proceed only if any NMNE keyword is present in the frame payload
|
||||
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
|
||||
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
|
||||
# Start with the root of the NMNE capture structure
|
||||
current_level = self.nmne
|
||||
current_level = self.nmne_config
|
||||
|
||||
# Update NMNE structure based on enabled settings
|
||||
if CAPTURE_BY_DIRECTION:
|
||||
if self.nmne_config.capture_by_direction:
|
||||
# Set or get the dictionary for the current direction
|
||||
current_level = current_level.setdefault("direction", {})
|
||||
current_level = current_level.setdefault(direction, {})
|
||||
|
||||
if CAPTURE_BY_IP_ADDRESS:
|
||||
if self.nmne_config.capture_by_ip_address:
|
||||
# Set or get the dictionary for the current IP address
|
||||
current_level = current_level.setdefault("ip_address", {})
|
||||
current_level = current_level.setdefault(ip_address, {})
|
||||
|
||||
if CAPTURE_BY_PROTOCOL:
|
||||
if self.nmne_config.capture_by_protocol:
|
||||
# Set or get the dictionary for the current protocol
|
||||
current_level = current_level.setdefault("protocol", {})
|
||||
current_level = current_level.setdefault(protocol, {})
|
||||
|
||||
if CAPTURE_BY_PORT:
|
||||
if self.nmne_config.capture_by_port:
|
||||
# Set or get the dictionary for the current port
|
||||
current_level = current_level.setdefault("port", {})
|
||||
current_level = current_level.setdefault(port, {})
|
||||
@@ -236,8 +226,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
keyword_level = current_level.setdefault("keywords", {})
|
||||
|
||||
# Increment the count for detected keywords in the payload
|
||||
if CAPTURE_BY_KEYWORD:
|
||||
for keyword in NMNE_CAPTURE_KEYWORDS:
|
||||
if self.nmne_config.capture_by_keyword:
|
||||
for keyword in self.nmne_config.nmne_capture_keywords:
|
||||
if keyword in frame_str:
|
||||
# Update the count for each keyword found
|
||||
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
|
||||
@@ -1067,7 +1057,7 @@ class Node(SimComponent):
|
||||
ip_address,
|
||||
network_interface.speed,
|
||||
"Enabled" if network_interface.enabled else "Disabled",
|
||||
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
|
||||
network_interface.nmne if self.nmne_config.capture_nmne else "Disabled",
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
Reference in New Issue
Block a user