Files
PrimAITE/tests/integration_tests/game_layer/observations/test_nic_observations.py

98 lines
3.2 KiB
Python

from pathlib import Path
from typing import Union
import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.nmne import CAPTURE_NMNE
from primaite.simulator.sim_container import Simulation
from tests import TEST_ASSETS_ROOT
# Path to the basic switched-network YAML config under the shared test assets directory.
# NOTE(review): not referenced by the tests visible in this chunk; presumably kept for use with load_config.
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
    """Return a PrimaiteGame built from the contents of a YAML config file.

    :param config_path: Path to the YAML configuration file.
    :return: The game produced by ``PrimaiteGame.from_config`` on the parsed YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    return PrimaiteGame.from_config(cfg)
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
    """Provide a new Simulation whose network is the ``example_network`` fixture."""
    fresh_sim = Simulation()
    # Swap in the shared example network in place of the simulation's own network.
    fresh_sim.network = example_network
    return fresh_sim
def test_nic(simulation):
    """Check the NIC observation space, its initial (enabled, zero-NMNE) state, and its disabled state."""
    client: Computer = simulation.network.get_node_by_hostname("client_1")
    interface: NIC = client.network_interface[1]
    observation = NICObservation(where=["network", "nodes", client.hostname, "NICs", 1])

    # Observation space: status is Discrete(3); each NMNE direction is Discrete(4).
    space = observation.space
    assert space["nic_status"] == spaces.Discrete(3)
    nmne_space = space["NMNE"]
    assert nmne_space["inbound"] == spaces.Discrete(4)
    assert nmne_space["outbound"] == spaces.Discrete(4)

    # Before any change: NIC reported as status 1 with no NMNEs in either direction.
    state = observation.observe(simulation.describe_state())
    assert state.get("nic_status") == 1  # enabled
    nmne_state = state.get("NMNE")
    assert nmne_state is not None
    assert nmne_state.get("inbound") == 0
    assert nmne_state.get("outbound") == 0

    # After the NIC is disabled, the observation reports status 2.
    interface.disable()
    state = observation.observe(simulation.describe_state())
    assert state.get("nic_status") == 2  # disabled
def test_nic_categories(simulation):
    """Check default and custom NMNE thresholds, and that invalid threshold triples are rejected."""
    client: Computer = simulation.network.get_node_by_hostname("client_1")

    def nic_where():
        # Build a fresh path list for every construction, so no list object is shared between observations.
        return ["network", "nodes", client.hostname, "NICs", 1]

    defaults = NICObservation(where=nic_where())
    assert defaults.high_nmne_threshold == 10  # default
    assert defaults.med_nmne_threshold == 5  # default
    assert defaults.low_nmne_threshold == 0  # default

    custom = NICObservation(
        where=nic_where(),
        low_nmne_threshold=3,
        med_nmne_threshold=6,
        high_nmne_threshold=9,
    )
    assert custom.high_nmne_threshold == 9
    assert custom.med_nmne_threshold == 6
    assert custom.low_nmne_threshold == 3

    # (low, med, high) triples that are not strictly increasing; construction is expected to raise.
    invalid_thresholds = [
        (9, 6, 9),  # low > med
        (3, 9, 9),  # med == high
    ]
    for low, med, high in invalid_thresholds:
        with pytest.raises(Exception):
            NICObservation(
                where=nic_where(),
                low_nmne_threshold=low,
                med_nmne_threshold=med,
                high_nmne_threshold=high,
            )