110 lines
3.8 KiB
Python
110 lines
3.8 KiB
Python
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
|
import pytest
|
|
from gymnasium import spaces
|
|
|
|
from primaite.game.agent.observations.link_observation import LinkObservation
|
|
from primaite.simulator.network.container import Network
|
|
from primaite.simulator.network.hardware.base import Link, Node
|
|
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
|
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
|
from primaite.simulator.network.hardware.nodes.host.server import Server
|
|
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
|
from primaite.simulator.sim_container import Simulation
|
|
|
|
|
|
@pytest.fixture(scope="function")
def simulation() -> Simulation:
    """Provide a simulation whose network holds one computer cabled directly to one server.

    Both hosts share the 192.168.1.0/24 subnet and gateway, use a zero
    ``start_up_duration`` and are powered on. Before handing the simulation
    back, the fixture asserts that the first link in the network is up and
    that the computer can ping the server's first interface.

    Returns:
        A ``Simulation`` whose ``network`` attribute is the populated network.
    """
    sim_instance = Simulation()
    net = Network()

    # Settings shared by both hosts; only hostname and address differ.
    shared_settings = {
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        "start_up_duration": 0,
    }

    client = Computer(hostname="computer", ip_address="192.168.1.2", **shared_settings)
    client.power_on()

    host = Server(hostname="server", ip_address="192.168.1.3", **shared_settings)
    host.power_on()

    # Cable interface 1 of the computer straight to interface 1 of the server.
    net.connect(client.network_interface[1], host.network_interface[1])

    first_link = next(iter(net.links.values()))
    assert first_link.is_up

    server_ip = host.network_interface.get(1).ip_address
    assert client.ping(server_ip)

    # Attach the populated network to the simulation.
    sim_instance.network = net
    return sim_instance
|
|
|
|
|
|
def test_link_observation():
    """Verify the link observation's space and values before and after traffic.

    Two computers are attached to one switch. For each switch-to-computer link,
    the observation must contain a ``PROTOCOLS`` -> ``ALL`` entry whose space is
    ``spaces.Discrete(11)``. That entry must read 0 while the links are idle
    and be positive once the computers have pinged each other.
    """
    net = Network()
    sim = Simulation(network=net)

    switch: Switch = Switch.from_config(
        config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"}
    )

    # Both hosts share every setting except hostname and address; the
    # generator builds computer_1 first, then computer_2.
    computer_1, computer_2 = (
        Computer.from_config(
            config={
                "type": "computer",
                "hostname": hostname,
                "ip_address": address,
                "subnet_mask": "255.255.255.0",
                "start_up_duration": 0,
            }
        )
        for hostname, address in (("computer_1", "10.0.0.1"), ("computer_2", "10.0.0.2"))
    )

    for host in (computer_1, computer_2):
        host.power_on()

    # Switch port N is cabled to interface 1 of computer_N.
    links = [
        net.connect(switch.network_interface[port], host.network_interface[1])
        for port, host in ((1, computer_1), (2, computer_2))
    ]
    for link in links:
        assert link is not None

    observations = [
        LinkObservation(where=["network", "links", "switch:eth-1<->computer_1:eth-1"]),
        LinkObservation(where=["network", "links", "switch:eth-2<->computer_2:eth-1"]),
    ]

    # Idle links: the protocol entry exists, uses a Discrete(11) space, and reads zero.
    state = sim.describe_state()
    readings = [observation.observe(state) for observation in observations]
    for observation, reading in zip(observations, readings):
        assert "PROTOCOLS" in reading
        assert "ALL" in reading["PROTOCOLS"]
        assert observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
        assert reading["PROTOCOLS"]["ALL"] == 0

    # Generate traffic in both directions, then re-observe the same links.
    computer_1.ping("10.0.0.2")
    computer_2.ping("10.0.0.1")

    state = sim.describe_state()
    readings = [observation.observe(state) for observation in observations]
    for reading in readings:
        assert reading["PROTOCOLS"]["ALL"] > 0
|