Files
PrimAITE/tests/integration_tests/game_layer/observations/test_nic_observations.py

227 lines
8.4 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Union
import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
computer: Computer = example_network.get_node_by_hostname("client_1")
server: Server = example_network.get_node_by_hostname("server_1")
web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
web_browser.run()
# Install DNS Client service on computer
computer.software_manager.install(DNSClient)
dns_client: DNSClient = computer.software_manager.software.get("dns-client")
# set dns server
dns_client.dns_server = server.network_interface[1].ip_address
# Install Web Server service on server
server.software_manager.install(WebServer)
web_server_service: WebServer = server.software_manager.software.get("web-server")
web_server_service.start()
# Install DNS Server service on server
server.software_manager.install(DNSServer)
dns_server: DNSServer = server.software_manager.software.get("dns-server")
# register arcd.com to DNS
dns_server.dns_register(
domain_name="arcd.com",
domain_ip_address=server.network_interfaces[next(iter(server.network_interfaces))].ip_address,
)
return sim
def test_nic(simulation):
"""Test the NIC observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic: NIC = pc.network_interface[1]
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 1 # enabled
assert observation_state.get("NMNE") is not None
assert observation_state["NMNE"].get("inbound") == 0
assert observation_state["NMNE"].get("outbound") == 0
nic.disable()
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 2 # disabled
def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
@pytest.mark.skip(reason="Feature not implemented yet")
def test_config_nic_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
high_nmne_threshold=9,
include_nmne=True,
)
assert nic_obs.high_nmne_threshold == 9
assert nic_obs.med_nmne_threshold == 6
assert nic_obs.low_nmne_threshold == 3
with pytest.raises(Exception):
# should throw an error
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
high_nmne_threshold=9,
include_nmne=True,
)
with pytest.raises(Exception):
# should throw an error
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,
high_nmne_threshold=9,
include_nmne=True,
)
def test_nic_monitored_traffic(simulation):
monitored_traffic = {
"icmp": ["NONE"],
"tcp": [
53,
],
}
pc: Computer = simulation.network.get_node_by_hostname("client_1")
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
where=["network", "nodes", pc.config.hostname, "NICs", 1],
include_nmne=False,
monitored_traffic=monitored_traffic,
)
simulation.pre_timestep(0) # apply timestep to whole sim
simulation.apply_timestep(0) # apply timestep to whole sim
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 0
assert traffic_obs["icmp"]["outbound"] == 0
# send a ping
assert pc.ping(target_ip_address=pc2.network_interface[1].ip_address)
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 1
assert traffic_obs["icmp"]["outbound"] == 1
simulation.pre_timestep(1) # apply timestep to whole sim
simulation.apply_timestep(1) # apply timestep to whole sim
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 0
assert traffic_obs["icmp"]["outbound"] == 0
assert traffic_obs["tcp"][53]["inbound"] == 0
assert traffic_obs["tcp"][53]["outbound"] == 0
# send a database query
browser: WebBrowser = pc.software_manager.software.get("web-browser")
browser.config.target_url = f"http://arcd.com/"
browser.get_webpage()
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 0
assert traffic_obs["icmp"]["outbound"] == 0
assert traffic_obs["tcp"][53]["inbound"] == 1
assert traffic_obs["tcp"][53]["outbound"] == 1 # getting a webpage sent a dns request out
simulation.pre_timestep(2) # apply timestep to whole sim
simulation.apply_timestep(2) # apply timestep to whole sim
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 0
assert traffic_obs["icmp"]["outbound"] == 0
assert traffic_obs["tcp"][53]["inbound"] == 0
assert traffic_obs["tcp"][53]["outbound"] == 0
def test_nic_monitored_traffic_config():
"""Test that the config loads the monitored traffic config correctly."""
game: PrimaiteGame = load_config(BASIC_CONFIG)
# should have icmp and DNS
defender_agent: ProxyAgent = game.agents.get("defender")
cur_obs = defender_agent.observation_manager.current_observation
assert cur_obs["NODES"]["HOST0"]["NICS"][1]["TRAFFIC"] == {
"icmp": {"inbound": 0, "outbound": 0},
"tcp": {53: {"inbound": 0, "outbound": 0}},
}