74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
import pytest
|
|
from gymnasium import spaces
|
|
|
|
from primaite.game.agent.observations.observations import LinkObservation
|
|
from primaite.simulator.network.container import Network
|
|
from primaite.simulator.network.hardware.base import Link, Node
|
|
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
|
from primaite.simulator.network.hardware.nodes.host.server import Server
|
|
from primaite.simulator.sim_container import Simulation
|
|
|
|
|
|
@pytest.fixture(scope="function")
def simulation() -> Simulation:
    """Provide a Simulation whose network contains one Computer linked to one Server.

    Both hosts are powered on and joined through their interface ``1``. Before
    the simulation is handed to the test, the fixture asserts that the first
    link in the network is up and that the computer can ping the server.

    Returns:
        The Simulation with the two-host network attached as ``network``.
    """
    sim_container = Simulation()
    lan = Network()

    # Keyword arguments that are identical for both hosts.
    shared_host_config = {
        "subnet_mask": "255.255.255.0",
        "default_gateway": "192.168.1.1",
        # NOTE(review): presumably 0 makes power_on() take effect immediately - confirm.
        "start_up_duration": 0,
    }

    # Client side of the link.
    client = Computer(hostname="computer", ip_address="192.168.1.2", **shared_host_config)
    client.power_on()

    # Server side of the link.
    host_server = Server(hostname="server", ip_address="192.168.1.3", **shared_host_config)
    host_server.power_on()

    # Wire the two hosts together through interface 1 on each.
    lan.connect(client.network_interface[1], host_server.network_interface[1])

    # Sanity checks: the new link must be up and traffic must flow across it.
    first_link = next(iter(lan.links.values()))
    assert first_link.is_up

    server_ip = host_server.network_interface.get(1).ip_address
    assert client.ping(server_ip)

    # Attach the prepared network to the simulation under test.
    sim_container.network = lan
    return sim_container
|
|
|
|
|
|
def test_link_observation(simulation):
    """Check that LinkObservation reports an "ALL" protocol value of 0, then 1 after a ping."""
    # Take the first link registered on the simulated network.
    observed_link: Link = next(iter(simulation.network.links.values()))

    client: Computer = simulation.network.get_node_by_hostname("computer")
    host_server: Server = simulation.network.get_node_by_hostname("server")

    # The fixture already pinged across the link while building the network;
    # applying a timestep resets that before observing.
    simulation.apply_timestep(0)

    link_obs = LinkObservation(where=["network", "links", observed_link.uuid])

    # Discrete(11) spans the integers 0 through 10 inclusive.
    all_protocols_space = link_obs.space["PROTOCOLS"]["ALL"]
    assert all_protocols_space == spaces.Discrete(11)

    # Observation taken before any new ping: expect 0.
    observation_state = link_obs.observe(simulation.describe_state())
    assert observation_state.get("PROTOCOLS") is not None
    assert observation_state["PROTOCOLS"]["ALL"] == 0

    # Send a ping across the link, then observe again.
    server_ip = host_server.network_interface.get(1).ip_address
    client.ping(server_ip)

    observation_state = link_obs.observe(simulation.describe_state())
    assert observation_state["PROTOCOLS"]["ALL"] == 1
|