#2829: Add check for capture_nmne

This commit is contained in:
Nick Todd
2024-09-09 09:12:20 +01:00
parent 08f742b3ec
commit 5ab42ead27
2 changed files with 7 additions and 2 deletions

View File

@@ -1,18 +1,21 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict, Optional
from typing import ClassVar, Dict, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.transport_layer import Port
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
"""Status information about a network interface within the simulation environment."""
capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne
"A dataclass defining malicious network events to be captured."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NICObservation."""
@@ -164,7 +167,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
for port in self.monitored_traffic[protocol]:
obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
if self.include_nmne:
if self.capture_nmne and self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})

View File

@@ -10,6 +10,7 @@ from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations import NICObservation
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
@@ -275,6 +276,7 @@ class PrimaiteGame:
links_cfg = network_config.get("links", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]