Files
PrimAITE/tests/integration_tests/network/test_broadcast.py

197 lines
6.8 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Tuple
import pytest
from pydantic import Field
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
class BroadcastTestService(Service, discriminator="broadcast-test-service"):
    """A service for sending broadcast and unicast messages over a network.

    Regardless of any ``name``/``port``/``protocol`` passed in, the service is always
    created as ``"BroadcastService"`` on the HTTP port using TCP.
    """

    class ConfigSchema(Service.ConfigSchema):
        """ConfigSchema for BroadcastTestService."""

        type: str = "broadcast-test-service"

    config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema())

    def __init__(self, **kwargs):
        # Force a fixed name/port/protocol so the software-manager lookup key and the
        # destination port used by the tests are predictable.
        kwargs["name"] = "BroadcastService"
        kwargs["port"] = PORT_LOOKUP["HTTP"]
        kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
        super().__init__(**kwargs)

    def describe_state(self) -> Dict:
        """Return the base ``Service`` state; this test service adds no fields of its own."""
        return super().describe_state()

    def unicast(self, ip_address: IPv4Address):
        """Send a ``"unicast"`` payload to a single host on the HTTP port.

        :param ip_address: Destination host address.
        :return: Whatever ``Service.send`` returns for this transmission.
        """
        # Pass the service's own protocol explicitly, matching ``broadcast``.
        return super().send(
            payload="unicast",
            dest_ip_address=ip_address,
            dest_port=PORT_LOOKUP["HTTP"],
            ip_protocol=self.protocol,
        )

    def broadcast(self, ip_network: IPv4Network):
        """Send a ``"broadcast"`` payload to every host in ``ip_network`` on the HTTP port.

        :param ip_network: Destination network; passed as the destination address to ``send``.
        :return: Whatever ``Service.send`` returns for this transmission.
        """
        return super().send(
            payload="broadcast",
            dest_ip_address=ip_network,
            dest_port=PORT_LOOKUP["HTTP"],
            ip_protocol=self.protocol,
        )
class BroadcastTestClient(Application, discriminator="broadcast-test-client"):
    """A client application to receive broadcast and unicast messages.

    The application is always created as ``"broadcast-test-client"`` on the HTTP port
    using TCP, and it records every payload it receives in ``payloads_received``.
    """

    # Was ``Service.ConfigSchema``: an Application's schema should derive from the
    # Application base schema.
    class ConfigSchema(Application.ConfigSchema):
        """ConfigSchema for BroadcastTestClient."""

        type: str = "broadcast-test-client"

    config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema())

    # Payloads delivered to this client, in arrival order; the tests assert on this list.
    # The per-instance default is made explicit with default_factory.
    payloads_received: List = Field(default_factory=list)

    def __init__(self, **kwargs):
        # Force a fixed name/port/protocol so the software-manager lookup key is predictable
        # and the client listens on the port the test service sends to.
        kwargs["name"] = "broadcast-test-client"
        kwargs["port"] = PORT_LOOKUP["HTTP"]
        kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
        super().__init__(**kwargs)

    def describe_state(self) -> Dict:
        """Return the base ``Application`` state; this test client adds no fields of its own."""
        return super().describe_state()

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """Record an incoming payload and log it to stdout.

        :param payload: The received payload (``"broadcast"`` or ``"unicast"`` in these tests).
        :param session_id: Session the payload arrived on (unused).
        :return: Always ``True``; the payload was accepted.
        """
        self.payloads_received.append(payload)
        print(f"Payload: {payload} received on node {self.sys_log.hostname}")
        return True
@pytest.fixture(scope="function")
def broadcast_network() -> Network:
    """Build a single-switch 192.168.1.0/24 LAN for the broadcast/unicast tests.

    Topology:
    - client_1 (192.168.1.2) and client_2 (192.168.1.3) each run a started
      ``BroadcastTestClient``.
    - server_1 (192.168.1.1) runs a started ``BroadcastTestService``.
    - All three hosts are cabled to switch_1 on ports 1, 2 and 3 respectively.
    """
    network = Network()
    # Addressing shared by every host on this LAN.
    lan_settings = dict(subnet_mask="255.255.255.0", default_gateway="192.168.1.1", start_up_duration=0)

    def _client_host(hostname: str, address: str) -> Computer:
        # Power on a computer and install/run the test client application on it.
        host = Computer(hostname=hostname, ip_address=address, **lan_settings)
        host.power_on()
        host.software_manager.install(BroadcastTestClient)
        host.software_manager.software["broadcast-test-client"].run()
        return host

    client_1 = _client_host("client_1", "192.168.1.2")
    client_2 = _client_host("client_2", "192.168.1.3")

    server_1 = Server(hostname="server_1", ip_address="192.168.1.1", **lan_settings)
    server_1.power_on()
    server_1.software_manager.install(BroadcastTestService)
    broadcaster: BroadcastTestService = server_1.software_manager.software["BroadcastService"]
    broadcaster.start()

    switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
    switch_1.power_on()

    for port_number, host in enumerate((client_1, client_2, server_1), start=1):
        network.connect(endpoint_a=host.network_interface[1], endpoint_b=switch_1.network_interface[port_number])

    return network
@pytest.fixture(scope="function")
def broadcast_service_and_clients(
    broadcast_network,
) -> Tuple[BroadcastTestService, BroadcastTestClient, BroadcastTestClient]:
    """Pull the test software out of ``broadcast_network``.

    :param broadcast_network: The LAN built by the ``broadcast_network`` fixture.
    :return: ``(service, client_1, client_2)``, where ``service`` is server_1's broadcast
        service and ``client_1``/``client_2`` are the client applications.
    """

    def _software_on(hostname: str, software_name: str):
        # Look up installed software by name on the given node.
        return broadcast_network.get_node_by_hostname(hostname).software_manager.software[software_name]

    client_1: BroadcastTestClient = _software_on("client_1", "broadcast-test-client")
    client_2: BroadcastTestClient = _software_on("client_2", "broadcast-test-client")
    # The key must be the service's registered name, "BroadcastService" (the same key used
    # when the network is built). The previous key, "broadcast-service", does not match it.
    service: BroadcastTestService = _software_on("server_1", "BroadcastService")
    return service, client_1, client_2
def test_broadcast_correct_subnet(broadcast_service_and_clients):
    """A broadcast to the clients' own 192.168.1.0/24 reaches every client exactly once."""
    service, first_client, second_client = broadcast_service_and_clients
    clients = (first_client, second_client)

    for client in clients:
        assert not client.payloads_received

    service.broadcast(IPv4Network("192.168.1.0/24"))

    for client in clients:
        assert client.payloads_received == ["broadcast"]
def test_broadcast_incorrect_subnet(broadcast_service_and_clients):
    """A broadcast to a foreign subnet (192.168.2.0/24) reaches neither client."""
    service, first_client, second_client = broadcast_service_and_clients
    clients = (first_client, second_client)

    for client in clients:
        assert not client.payloads_received

    foreign_subnet = IPv4Network("192.168.2.0/24")
    service.broadcast(foreign_subnet)

    for client in clients:
        assert not client.payloads_received
def test_unicast_correct_address(broadcast_service_and_clients):
    """A unicast to 192.168.1.2 (client_1's address) is delivered to client_1 only."""
    service, recipient, bystander = broadcast_service_and_clients

    assert not recipient.payloads_received
    assert not bystander.payloads_received

    destination = IPv4Address("192.168.1.2")
    service.unicast(destination)

    assert recipient.payloads_received == ["unicast"]
    assert not bystander.payloads_received
def test_unicast_incorrect_address(broadcast_service_and_clients):
    """A unicast to an address outside the LAN (192.168.2.2) reaches neither client."""
    service, first_client, second_client = broadcast_service_and_clients
    clients = (first_client, second_client)

    for client in clients:
        assert not client.payloads_received

    unreachable_host = IPv4Address("192.168.2.2")
    service.unicast(unreachable_host)

    for client in clients:
        assert not client.payloads_received