diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cd0180db..2e7ee735 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -16,7 +16,6 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort -from primaite.session.io import store_nmne_config from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState @@ -27,6 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -265,7 +265,7 @@ class PrimaiteGame: nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) # Set the NMNE capture config - NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) + NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) for node_cfg in nodes_cfg: n_type = node_cfg["type"] diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index c634e835..78d7cb3c 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -131,48 +131,3 @@ class PrimaiteIO: new = cls(settings=cls.Settings(**config)) return new - - -class NMNEConfig(BaseModel): - """Store all the information to perform NMNE operations.""" - - capture_nmne: bool = False - """Indicates whether Malicious Network Events (MNEs) should be captured.""" - nmne_capture_keywords: List[str] = [] - """List of keywords to identify malicious network events.""" - capture_by_direction: bool = True - """Captures should be organized by traffic direction (inbound/outbound).""" - capture_by_ip_address: bool = False - """Captures should be organized by source or destination IP address.""" - capture_by_protocol: bool = False - """Captures should be organized by network protocol (e.g., TCP, UDP).""" - capture_by_port: bool = False - """Captures should be organized by source or destination port.""" - capture_by_keyword: bool = False - """Captures should be filtered and categorised based on specific keywords.""" - - -def store_nmne_config(nmne_config: Dict) -> NMNEConfig: - """ - Store configuration for capturing Malicious Network Events (MNEs). - - This function updates settings related to NMNE capture, stored in NMNEConfig including whether - to capture NMNEs and the keywords to use for identifying NMNEs. - - The function ensures that the settings are updated only if they are provided in the - `nmne_config` dictionary, and maintains type integrity by relying on pydantic validators. - - :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys - include: - "capture_nmne" (bool) to indicate whether NMNEs should be captured; - "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. - :rvar class with data read from config file. - """ - nmne_capture_keywords: List[str] = [] - # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect - capture_nmne = nmne_config.get("capture_nmne", False) - - # Update the NMNE capture keywords, appending new keywords if provided - nmne_capture_keywords += nmne_config.get("nmne_capture_keywords", []) - - return NMNEConfig(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index aafdbe5c..50549389 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -14,12 +14,12 @@ from pydantic import BaseModel, Field from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse -from primaite.session.io import NMNEConfig from primaite.simulator import SIM_OUTPUT from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.applications.application import Application diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py new file mode 100644 index 00000000..c9cff5de --- /dev/null +++ b/src/primaite/simulator/network/nmne.py @@ -0,0 +1,25 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import List + +from pydantic import BaseModel, ConfigDict + + +class NMNEConfig(BaseModel): + """Store all the information to perform NMNE operations.""" + + model_config = ConfigDict(extra="forbid") + + capture_nmne: bool = False + """Indicates whether Malicious Network Events (MNEs) should be captured.""" + nmne_capture_keywords: List[str] = [] + """List of keywords to identify malicious network events.""" + capture_by_direction: bool = True + """Captures should be organized by traffic direction (inbound/outbound).""" + capture_by_ip_address: bool = False + """Captures should be organized by source or destination IP address.""" + capture_by_protocol: bool = False + """Captures should be organized by network protocol (e.g., TCP, UDP).""" + capture_by_port: bool = False + """Captures should be organized by source or destination port.""" + capture_by_keyword: bool = False + """Captures should be filtered and categorised based on specific keywords.""" diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 7f86d26d..ef789ba7 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -9,11 +9,11 @@ from gymnasium import spaces from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame -from primaite.session.io import store_nmne_config from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser @@ -87,7 +87,7 @@ def test_nic(simulation): } # Apply the NMNE configuration settings - NetworkInterface.nmne_config = store_nmne_config(nmne_config) + NetworkInterface.nmne_config = NMNEConfig(**nmne_config) assert nic_obs.space["nic_status"] == spaces.Discrete(3) assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index b4162e58..debf5b1c 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,9 +1,9 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from primaite.game.agent.observations.nic_observations import NICObservation -from primaite.session.io import store_nmne_config from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection @@ -35,7 +35,7 @@ def test_capture_nmne(uc2_network: Network): } # Apply the NMNE configuration settings - NIC.nmne_config = store_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers assert web_server_nic.nmne == {} @@ -112,7 +112,7 @@ def test_describe_state_nmne(uc2_network: Network): } # Apply the NMNE configuration settings - NIC.nmne_config = store_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers web_server_nic_state = web_server_nic.describe_state() @@ -221,7 +221,7 @@ def test_capture_nmne_observations(uc2_network: Network): } # Apply the NMNE configuration settings - NIC.nmne_config = store_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Define observations for the NICs of the database and web servers db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)