diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..c5da8767 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,18 +1,21 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import ClassVar, Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.transport_layer import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne + "A dataclass defining malicious network events to be captured." class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" @@ -164,7 +167,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): for port in self.monitored_traffic[protocol]: obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} - if self.include_nmne: + if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..9c0f49af 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -10,6 +10,7 @@ from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations import NICObservation from primaite.game.agent.rewards import RewardFunction, SharedReward from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent @@ -275,6 +276,7 @@ class PrimaiteGame: links_cfg = network_config.get("links", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) + NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"]