#2829: Add check for capture_nmne
This commit is contained in:
@@ -1,18 +1,21 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import ClassVar, Dict, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne
|
||||
"A dataclass defining malicious network events to be captured."
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NICObservation."""
|
||||
@@ -164,7 +167,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
for port in self.monitored_traffic[protocol]:
|
||||
obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||
|
||||
if self.include_nmne:
|
||||
if self.capture_nmne and self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
|
||||
@@ -10,6 +10,7 @@ from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations import NICObservation
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
@@ -275,6 +276,7 @@ class PrimaiteGame:
|
||||
links_cfg = network_config.get("links", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
|
||||
NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
|
||||
Reference in New Issue
Block a user