diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..ed2bb7f9 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,19 +1,23 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import ClassVar, Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.transport_layer import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne + "A Boolean specifying whether malicious network events should be captured." + class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" @@ -164,7 +168,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): for port in self.monitored_traffic[protocol]: obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} - if self.include_nmne: + if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 0e7b8c23..e683772b 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent +from primaite.game.agent.observations import NICObservation from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction, SharedReward from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent @@ -279,6 +280,7 @@ class PrimaiteGame: links_cfg = network_config.get("links", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) + NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"] diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..ced598f0 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -77,6 +77,14 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { "capture_nmne": True, # Enable the capture of MNEs diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index debf5b1c..1499df9a 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,5 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from itertools import product + +import yaml + +from primaite.config.load import data_manipulation_config_path from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server @@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network): assert web_nic_obs["outbound"] == expected_nmne assert db_nic_obs["inbound"] == expected_nmne uc2_network.apply_timestep(timestep=0) + + +def test_nmne_parameter_settings(): + """ + Check that the four permutations of the values of capture_nmne and + include_nmne work as expected. + """ + + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + DEFENDER = 3 + for capture, include in product([True, False], [True, False]): + cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture + cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include + PrimaiteGymEnv(env_config=cfg)