diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml index 81848b2d..dfd200f3 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml @@ -129,6 +129,10 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - hostname: client type: computer diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 68abf9f2..1e4e8825 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency -from primaite.simulator.network.hardware.base import NodeOperatingState, UserManager +from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server @@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -264,6 +264,8 @@ class PrimaiteGame: nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) + # Set the NMNE capture config + NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) for node_cfg in nodes_cfg: n_type = node_cfg["type"] @@ -539,10 +541,7 @@ class PrimaiteGame: # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. game.setup_reward_sharing() - # Set the NMNE capture config - set_nmne_config(network_config.get("nmne_config", {})) game.update_agents(game.get_sim_state()) - return game def setup_reward_sharing(self): diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 1d320824..8b9d2688 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -11,7 +11,6 @@ from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field, validate_call -import primaite.simulator.network.nmne from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse @@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.nmne import ( - CAPTURE_BY_DIRECTION, - CAPTURE_BY_IP_ADDRESS, - CAPTURE_BY_KEYWORD, - CAPTURE_BY_PORT, - CAPTURE_BY_PROTOCOL, - CAPTURE_NMNE, - NMNE_CAPTURE_KEYWORDS, -) +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -109,8 +100,11 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." + nmne_config: ClassVar[NMNEConfig] = NMNEConfig() + "A dataclass defining malicious network events to be captured." + nmne: Dict = Field(default_factory=lambda: {}) - "A dict containing details of the number of malicious network events captured." + "A dict containing details of the number of malicious events captured." traffic: Dict = Field(default_factory=lambda: {}) "A dict containing details of the inbound and outbound traffic by port and protocol." @@ -168,8 +162,8 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) - if CAPTURE_NMNE: - state.update({"nmne": {k: v for k, v in self.nmne.items()}}) + if self.nmne_config and self.nmne_config.capture_nmne: + state.update({"nmne": self.nmne}) state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)}) return state @@ -202,7 +196,7 @@ class NetworkInterface(SimComponent, ABC): :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. """ # Exit function if NMNE capturing is disabled - if not CAPTURE_NMNE: + if not (self.nmne_config and self.nmne_config.capture_nmne): return # Initialise basic frame data variables @@ -223,27 +217,27 @@ class NetworkInterface(SimComponent, ABC): frame_str = str(frame.payload) # Proceed only if any NMNE keyword is present in the frame payload - if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS): + if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords): # Start with the root of the NMNE capture structure current_level = self.nmne # Update NMNE structure based on enabled settings - if CAPTURE_BY_DIRECTION: + if self.nmne_config.capture_by_direction: # Set or get the dictionary for the current direction current_level = current_level.setdefault("direction", {}) current_level = current_level.setdefault(direction, {}) - if CAPTURE_BY_IP_ADDRESS: + if self.nmne_config.capture_by_ip_address: # Set or get the dictionary for the current IP address current_level = current_level.setdefault("ip_address", {}) current_level = current_level.setdefault(ip_address, {}) - if CAPTURE_BY_PROTOCOL: + if self.nmne_config.capture_by_protocol: # Set or get the dictionary for the current protocol current_level = current_level.setdefault("protocol", {}) current_level = current_level.setdefault(protocol, {}) - if CAPTURE_BY_PORT: + if self.nmne_config.capture_by_port: # Set or get the dictionary for the current port current_level = current_level.setdefault("port", {}) current_level = current_level.setdefault(port, {}) @@ -252,8 +246,8 @@ class NetworkInterface(SimComponent, ABC): keyword_level = current_level.setdefault("keywords", {}) # Increment the count for detected keywords in the payload - if CAPTURE_BY_KEYWORD: - for keyword in NMNE_CAPTURE_KEYWORDS: + if self.nmne_config.capture_by_keyword: + for keyword in self.nmne_config.nmne_capture_keywords: if keyword in frame_str: # Update the count for each keyword found keyword_level[keyword] = keyword_level.get(keyword, 0) + 1 @@ -1848,7 +1842,7 @@ class Node(SimComponent): ip_address, network_interface.speed, "Enabled" if network_interface.enabled else "Disabled", - network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled", + network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled", ] ) print(table) diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index 5c0c657b..c9cff5de 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -1,48 +1,25 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Dict, Final, List +from typing import List -CAPTURE_NMNE: bool = True -"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True.""" - -NMNE_CAPTURE_KEYWORDS: List[str] = [] -"""List of keywords to identify malicious network events.""" - -# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically -CAPTURE_BY_DIRECTION: Final[bool] = True -"""Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" -CAPTURE_BY_IP_ADDRESS: Final[bool] = False -"""Flag to determine if captures should be organized by source or destination IP address.""" -CAPTURE_BY_PROTOCOL: Final[bool] = False -"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP).""" -CAPTURE_BY_PORT: Final[bool] = False -"""Flag to determine if captures should be organized by source or destination port.""" -CAPTURE_BY_KEYWORD: Final[bool] = False -"""Flag to determine if captures should be filtered and categorised based on specific keywords.""" +from pydantic import BaseModel, ConfigDict -def set_nmne_config(nmne_config: Dict): - """ - Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary. +class NMNEConfig(BaseModel): + """Store all the information to perform NMNE operations.""" - This function updates global settings related to NMNE capture, including whether to capture NMNEs and what - keywords to use for identifying NMNEs. + model_config = ConfigDict(extra="forbid") - The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary, - and maintains type integrity by checking the types of the provided values. - - :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include: - "capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings) - to specify keywords for NMNE identification. - """ - global NMNE_CAPTURE_KEYWORDS - global CAPTURE_NMNE - - # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect - CAPTURE_NMNE = nmne_config.get("capture_nmne", False) - if not isinstance(CAPTURE_NMNE, bool): - CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean - - # Update the NMNE capture keywords, appending new keywords if provided - NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", []) - if not isinstance(NMNE_CAPTURE_KEYWORDS, list): - NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list + capture_nmne: bool = False + """Indicates whether Malicious Network Events (MNEs) should be captured.""" + nmne_capture_keywords: List[str] = [] + """List of keywords to identify malicious network events.""" + capture_by_direction: bool = True + """Captures should be organized by traffic direction (inbound/outbound).""" + capture_by_ip_address: bool = False + """Captures should be organized by source or destination IP address.""" + capture_by_protocol: bool = False + """Captures should be organized by network protocol (e.g., TCP, UDP).""" + capture_by_port: bool = False + """Captures should be organized by source or destination port.""" + capture_by_keyword: bool = False + """Captures should be filtered and categorised based on specific keywords.""" diff --git a/src/primaite/simulator/network/protocols/icmp.py b/src/primaite/simulator/network/protocols/icmp.py index 743e2375..9f0626f0 100644 --- a/src/primaite/simulator/network/protocols/icmp.py +++ b/src/primaite/simulator/network/protocols/icmp.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Union from pydantic import BaseModel, field_validator, validate_call -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo from primaite import getLogger @@ -96,7 +96,7 @@ class ICMPPacket(BaseModel): @field_validator("icmp_code") # noqa @classmethod - def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int: + def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int: """Validates the icmp_type and icmp_code.""" icmp_type = info.data["icmp_type"] if get_icmp_type_code_description(icmp_type, v): diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 8cbd3ae9..c83cadc8 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -99,7 +99,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 69187fa3..fed0f52d 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -92,7 +92,7 @@ agents: - NONE tcp: - DNS - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index de861dcc..3d60eb6e 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -111,7 +111,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index fd5b1bf8..2292616d 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -68,7 +68,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml index 59bc15f9..bd0fb61f 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/fix_duration_one_item.yaml @@ -89,7 +89,7 @@ agents: - NONE tcp: - DNS - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/scenario_with_placeholders/scenario.yaml b/tests/assets/configs/scenario_with_placeholders/scenario.yaml index 81848b2d..ef930a1a 100644 --- a/tests/assets/configs/scenario_with_placeholders/scenario.yaml +++ b/tests/assets/configs/scenario_with_placeholders/scenario.yaml @@ -44,7 +44,7 @@ agents: num_files: 1 num_nics: 1 include_num_access: false - include_nmne: true + include_nmne: false - type: LINKS label: LINKS diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index 1acb05a9..1a28258b 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -89,7 +89,7 @@ agents: - NONE tcp: - DNS - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index eb8103e8..27cfa240 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -120,7 +120,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/conftest.py b/tests/conftest.py index ca704461..2d605c94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT -rayinit(local_mode=True) +rayinit() ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 88dd2bd5..ef789ba7 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -9,9 +9,11 @@ from gymnasium import spaces from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser @@ -75,6 +77,18 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": [ + "DELETE", + "ENCRYPT", + ], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + NetworkInterface.nmne_config = NMNEConfig(**nmne_config) + assert nic_obs.space["nic_status"] == spaces.Discrete(3) assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4) @@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation): pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic ) simulation.pre_timestep(0) # apply timestep to whole sim diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index a8f1f245..debf5b1c 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,12 +1,14 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection -def test_capture_nmne(uc2_network): +def test_capture_nmne(uc2_network: Network): """ Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured. @@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network): } # Apply the NMNE configuration settings - set_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers assert web_server_nic.nmne == {} @@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network): assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}} -def test_describe_state_nmne(uc2_network): +def test_describe_state_nmne(uc2_network: Network): """ Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state. @@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network): } # Apply the NMNE configuration settings - set_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers web_server_nic_state = web_server_nic.describe_state() @@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network): assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}} -def test_capture_nmne_observations(uc2_network): +def test_capture_nmne_observations(uc2_network: Network): """ Tests the NICObservation class's functionality within a simulated network environment. @@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network): } # Apply the NMNE configuration settings - set_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Define observations for the NICs of the database and web servers db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)