diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index c5da8767..ed2bb7f9 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -14,8 +14,9 @@ from primaite.simulator.network.transmission.transport_layer import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne - "A dataclass defining malicious network events to be captured." + "A Boolean specifying whether malicious network events should be captured." class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 9afdbea6..64cdf63b 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent -from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.observations import NICObservation +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction, SharedReward from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..ced598f0 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -77,6 +77,14 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { "capture_nmne": True, # Enable the capture of MNEs diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index debf5b1c..1499df9a 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,5 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from itertools import product + +import yaml + +from primaite.config.load import data_manipulation_config_path from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server @@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network): assert web_nic_obs["outbound"] == expected_nmne assert db_nic_obs["inbound"] == expected_nmne uc2_network.apply_timestep(timestep=0) + + +def test_nmne_parameter_settings(): + """ + Check that the four permutations of the values of capture_nmne and + include_nmne work as expected. + """ + + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + DEFENDER = 3 + for capture, include in product([True, False], [True, False]): + cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture + cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include + PrimaiteGymEnv(env_config=cfg)