# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import pytest
from gymnasium import spaces

from primaite.game.agent.observations.link_observation import LinkObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.sim_container import Simulation


@pytest.fixture(scope="function")
def simulation() -> Simulation:
    """Build a simulation whose network is one computer linked directly to one server.

    Both hosts are powered on before being connected. The fixture itself asserts
    that the resulting link is up and that the computer can ping the server, so a
    test receiving it starts from a verified, working topology.

    :return: The simulation with the two-host network attached.
    """
    sim = Simulation()
    net = Network()

    # Settings that are the same for both hosts.
    common_settings = {
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        "start_up_duration": 0,
    }

    # Client side of the link.
    pc = Computer(hostname="computer", ip_address="192.168.1.2", **common_settings)
    pc.power_on()

    # Server side of the link.
    srv = Server(hostname="server", ip_address="192.168.1.3", **common_settings)
    srv.power_on()

    # Wire the first interface of each host together.
    net.connect(pc.network_interface[1], srv.network_interface[1])

    # Sanity checks: the only link must be up and the hosts must reach each other.
    only_link = next(iter(net.links.values()))
    assert only_link.is_up
    server_address = srv.network_interface.get(1).ip_address
    assert pc.ping(server_address)

    # Hand the prepared network over to the simulation.
    sim.network = net
    return sim


def test_link_observation():
    """Verify the structure of a link observation and that it reacts to traffic.

    Two computers are attached to a switch. Each link's observation must expose
    ``PROTOCOLS -> ALL`` with a ``Discrete(11)`` space, read 0 before any ping is
    issued, and read a positive value once the computers have pinged each other.
    """
    network = Network()
    sim = Simulation(network=network)

    switch_config = {"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"}
    switch: Switch = Switch.from_config(config=switch_config)

    pc_a_config = {
        "type": "computer",
        "hostname": "computer_1",
        "ip_address": "10.0.0.1",
        "subnet_mask": "255.255.255.0",
        "start_up_duration": 0,
    }
    pc_a: Computer = Computer.from_config(config=pc_a_config)

    pc_b_config = {
        "type": "computer",
        "hostname": "computer_2",
        "ip_address": "10.0.0.2",
        "subnet_mask": "255.255.255.0",
        "start_up_duration": 0,
    }
    pc_b: Computer = Computer.from_config(config=pc_b_config)

    for host in (pc_a, pc_b):
        host.power_on()

    # Switch port 1 <-> computer_1, switch port 2 <-> computer_2.
    first_link = network.connect(switch.network_interface[1], pc_a.network_interface[1])
    second_link = network.connect(switch.network_interface[2], pc_b.network_interface[1])
    assert first_link is not None
    assert second_link is not None

    # One observation per link, addressed by the link's key in the state tree.
    observers = (
        LinkObservation(where=["network", "links", "switch:eth-1<->computer_1:eth-1"]),
        LinkObservation(where=["network", "links", "switch:eth-2<->computer_2:eth-1"]),
    )

    snapshot = sim.describe_state()
    idle_readings = [observer.observe(snapshot) for observer in observers]

    # Each check is applied to link 1 and then link 2 before moving to the next check.
    for reading in idle_readings:
        assert "PROTOCOLS" in reading
    for reading in idle_readings:
        assert "ALL" in reading["PROTOCOLS"]
    for observer in observers:
        assert observer.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
    for reading in idle_readings:
        assert reading["PROTOCOLS"]["ALL"] == 0

    # Test that the link observation is updated when a packet is sent
    pc_a.ping("10.0.0.2")
    pc_b.ping("10.0.0.1")

    snapshot = sim.describe_state()
    busy_readings = [observer.observe(snapshot) for observer in observers]
    for reading in busy_readings:
        assert reading["PROTOCOLS"]["ALL"] > 0