# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Union

import pytest
import yaml
from gymnasium import spaces

from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT

BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"


def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
    """Returns a PrimaiteGame object which loads the contents of a given yaml path.

    :param config_path: Path to a YAML game configuration file.
    :return: The game built by ``PrimaiteGame.from_config`` from the parsed YAML.
    """
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    return PrimaiteGame.from_config(cfg)


@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
    """Build a Simulation around ``example_network`` with web and DNS software installed.

    ``client_1`` runs its web browser and gets a DNS client pointed at ``server_1``.
    ``server_1`` gets a started web server and a DNS server with ``arcd.com`` registered.
    """
    sim = Simulation()

    # set simulation network as example network
    sim.network = example_network

    computer: Computer = example_network.get_node_by_hostname("client_1")
    server: Server = example_network.get_node_by_hostname("server_1")

    web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
    web_browser.run()

    # Install DNS Client service on computer
    computer.software_manager.install(DNSClient)
    dns_client: DNSClient = computer.software_manager.software.get("dns-client")
    # set dns server
    dns_client.dns_server = server.network_interface[1].ip_address

    # Install Web Server service on server
    server.software_manager.install(WebServer)
    web_server_service: WebServer = server.software_manager.software.get("web-server")
    web_server_service.start()

    # Install DNS Server service on server
    server.software_manager.install(DNSServer)
    dns_server: DNSServer = server.software_manager.software.get("dns-server")
    # register arcd.com to DNS
    # NOTE(review): this uses the first entry of ``network_interfaces`` whereas the DNS
    # client above targets ``network_interface[1]``; presumably the same NIC - confirm.
    dns_server.dns_register(
        domain_name="arcd.com",
        domain_ip_address=server.network_interfaces[next(iter(server.network_interfaces))].ip_address,
    )

    return sim


def test_nic(simulation, monkeypatch):
    """Test the NIC observation."""
    pc: Computer = simulation.network.get_node_by_hostname("client_1")

    nic: NIC = pc.network_interface[1]

    nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)

    # The Simulation object created by the fixture also creates the
    # NICObservation class with the NICObservation.capture_nmne class variable
    # set to False. Under normal (non-test) circumstances this class variable
    # is set from a config file such as data_manipulation.yaml. So although
    # capture_nmne is set to True in the NetworkInterface class it's still False
    # in the NICObservation class so we set it now.
    nic_obs.capture_nmne = True

    # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
    nmne_config = {
        "capture_nmne": True,  # Enable the capture of MNEs
        "nmne_capture_keywords": [
            "DELETE",
            "ENCRYPT",
        ],  # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
    }

    # Apply the NMNE configuration settings. This is class-level state shared by every
    # NetworkInterface, so use monkeypatch to restore the previous value after the test
    # instead of leaking it into other tests.
    monkeypatch.setattr(NetworkInterface, "nmne_config", NMNEConfig(**nmne_config))

    assert nic_obs.space["nic_status"] == spaces.Discrete(3)
    assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
    assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)

    observation_state = nic_obs.observe(simulation.describe_state())
    assert observation_state.get("nic_status") == 1  # enabled
    assert observation_state.get("NMNE") is not None
    assert observation_state["NMNE"].get("inbound") == 0
    assert observation_state["NMNE"].get("outbound") == 0

    nic.disable()
    observation_state = nic_obs.observe(simulation.describe_state())
    assert observation_state.get("nic_status") == 2  # disabled


def test_nic_categories(simulation):
    """Test the NIC observation nmne count categories."""
    pc: Computer = simulation.network.get_node_by_hostname("client_1")

    nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)

    assert nic_obs.high_nmne_threshold == 10  # default
    assert nic_obs.med_nmne_threshold == 5  # default
    assert nic_obs.low_nmne_threshold == 0  # default


def test_config_nic_categories(simulation):
    """Test that custom NMNE thresholds are applied and inconsistent ones are rejected."""
    pc: Computer = simulation.network.get_node_by_hostname("client_1")
    nic_obs = NICObservation(
        where=["network", "nodes", pc.config.hostname, "NICs", 1],
        thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}},
        include_nmne=True,
    )

    assert nic_obs.high_nmne_threshold == 9
    assert nic_obs.med_nmne_threshold == 6
    assert nic_obs.low_nmne_threshold == 3

    # The two cases below presumably violate a "low < medium < high" ordering rule.
    with pytest.raises(Exception):
        # should throw an error: low (9) is above medium (6)
        NICObservation(
            where=["network", "nodes", pc.config.hostname, "NICs", 1],
            thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}},
            include_nmne=True,
        )

    with pytest.raises(Exception):
        # should throw an error: medium (9) equals high (9)
        NICObservation(
            where=["network", "nodes", pc.config.hostname, "NICs", 1],
            thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}},
            include_nmne=True,
        )


def test_nic_monitored_traffic(simulation):
    """Test that monitored ICMP and TCP/53 traffic is counted per timestep on the NIC."""
    monitored_traffic = {
        "icmp": ["NONE"],
        "tcp": [
            53,
        ],
    }

    pc: Computer = simulation.network.get_node_by_hostname("client_1")
    pc2: Computer = simulation.network.get_node_by_hostname("client_2")

    nic_obs = NICObservation(
        where=["network", "nodes", pc.config.hostname, "NICs", 1],
        include_nmne=False,
        monitored_traffic=monitored_traffic,
    )

    simulation.pre_timestep(0)  # apply timestep to whole sim
    simulation.apply_timestep(0)  # apply timestep to whole sim
    traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

    assert traffic_obs["icmp"]["inbound"] == 0
    assert traffic_obs["icmp"]["outbound"] == 0

    # send a ping
    assert pc.ping(target_ip_address=pc2.network_interface[1].ip_address)
    traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

    assert traffic_obs["icmp"]["inbound"] == 1
    assert traffic_obs["icmp"]["outbound"] == 1

    # Advancing the timestep resets the per-step traffic counters.
    simulation.pre_timestep(1)  # apply timestep to whole sim
    simulation.apply_timestep(1)  # apply timestep to whole sim
    traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

    assert traffic_obs["icmp"]["inbound"] == 0
    assert traffic_obs["icmp"]["outbound"] == 0

    assert traffic_obs["tcp"][53]["inbound"] == 0
    assert traffic_obs["tcp"][53]["outbound"] == 0

    # fetch a webpage (this triggers a DNS lookup for arcd.com)
    browser: WebBrowser = pc.software_manager.software.get("web-browser")
    browser.config.target_url = "http://arcd.com/"
    browser.get_webpage()

    traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

    assert traffic_obs["icmp"]["inbound"] == 0
    assert traffic_obs["icmp"]["outbound"] == 0

    assert traffic_obs["tcp"][53]["inbound"] == 1
    assert traffic_obs["tcp"][53]["outbound"] == 1  # getting a webpage sent a dns request out

    simulation.pre_timestep(2)  # apply timestep to whole sim
    simulation.apply_timestep(2)  # apply timestep to whole sim
    traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

    assert traffic_obs["icmp"]["inbound"] == 0
    assert traffic_obs["icmp"]["outbound"] == 0

    assert traffic_obs["tcp"][53]["inbound"] == 0
    assert traffic_obs["tcp"][53]["outbound"] == 0


def test_nic_monitored_traffic_config():
    """Test that the config loads the monitored traffic config correctly."""
    game: PrimaiteGame = load_config(BASIC_CONFIG)

    # should have icmp and DNS
    defender_agent: ProxyAgent = game.agents.get("defender")
    cur_obs = defender_agent.observation_manager.current_observation

    assert cur_obs["NODES"]["HOST0"]["NICS"][1]["TRAFFIC"] == {
        "icmp": {"inbound": 0, "outbound": 0},
        "tcp": {53: {"inbound": 0, "outbound": 0}},
    }