Files
PrimAITE/tests/integration_tests/extensions/nodes/giga_switch.py
2025-02-14 11:38:15 +00:00

126 lines
5.2 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import Dict, Literal
from prettytable import MARKDOWN, PrettyTable
from primaite import _LOGGER
from primaite.exceptions import NetworkError
from primaite.simulator.network.hardware.base import Link
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.switch import SwitchPort
from primaite.simulator.network.transmission.data_link_layer import Frame
class GigaSwitch(NetworkNode, discriminator="gigaswitch"):
    """
    A class representing a Layer 2 network switch.

    Registered under the ``"gigaswitch"`` discriminator. The number of ports is taken from ``config.num_ports``
    (declared on :class:`GigaSwitch.ConfigSchema`, default 24); one :class:`SwitchPort` is attached per port when
    the switch is constructed.
    """

    class ConfigSchema(NetworkNode.ConfigSchema):
        """Configuration schema for GigaSwitch nodes."""

        type: Literal["gigaswitch"] = "gigaswitch"
        num_ports: int = 24
        "The number of ports on the switch."

    network_interfaces: Dict[str, SwitchPort] = {}
    "The SwitchPorts on the Switch."
    network_interface: Dict[int, SwitchPort] = {}
    "The SwitchPorts on the Switch by port id."
    mac_address_table: Dict[str, SwitchPort] = {}
    "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."

    def __init__(self, **kwargs):
        """
        Create the switch and attach ``config.num_ports`` SwitchPorts to it.

        :param kwargs: Passed unchanged to the :class:`NetworkNode` constructor.
        """
        print("--- Extended Component: GigaSwitch ---")
        super().__init__(**kwargs)
        # num_ports lives on ConfigSchema, not on the node itself, so it must be read through ``self.config``
        # (reading ``self.num_ports`` raises AttributeError).
        # NOTE(review): this assumes ``self.config`` is a GigaSwitch.ConfigSchema instance (e.g. created via the
        # node's config-loading path); confirm the parent's ``config`` field accepts/keeps the subclass instance.
        for _ in range(self.config.num_ports):
            self.connect_nic(SwitchPort())

    def _install_system_software(self):
        """Install no system software: this switch runs without any."""
        pass

    def show(self, markdown: bool = False):
        """
        Prints a table of the SwitchPorts on the Switch.

        :param markdown: If True, outputs the table in markdown format. Default is False.
        """
        table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
        if markdown:
            table.set_style(MARKDOWN)
        table.align = "l"
        table.title = f"{self.config.hostname} Switch Ports"
        for port_num, port in self.network_interface.items():
            table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
        print(table)

    def describe_state(self) -> Dict:
        """
        Produce a dictionary describing the current state of this object.

        Extends the parent state with ``ports`` (port number -> port state), ``num_ports`` (the configured port
        count) and ``mac_address_table`` (MAC address -> port number).

        :return: Current state of this object and child objects.
        """
        state = super().describe_state()
        state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
        # Reported from config so consumers don't have to count ``ports``; num_ports is a ConfigSchema field.
        state["num_ports"] = self.config.num_ports
        state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
        return state

    def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
        """
        Private method to add or update an entry in the MAC address table.

        If ``mac_address`` is already mapped to ``switch_port`` nothing changes. If it is mapped to a different port,
        the old entry is removed (and logged) before the new mapping is added.

        :param mac_address: MAC address to be added.
        :param switch_port: Corresponding SwitchPort object.
        """
        current_port = self.mac_address_table.get(mac_address)
        if current_port is not None:
            if current_port == switch_port:
                # Already learned on this port; nothing to update.
                return
            # Learned on a different port than before: drop the stale mapping first.
            self.mac_address_table.pop(mac_address)
            self.sys_log.info(f"Removed MAC table entry: Port {current_port.port_num} -> {mac_address}")
        self.mac_address_table[mac_address] = switch_port
        self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")

    def receive_frame(self, frame: Frame, from_network_interface: SwitchPort):
        """
        Forward a frame to the appropriate port based on the destination MAC address.

        The source MAC is first learned against the receiving port. A frame whose destination is known (and is not
        the broadcast address) is sent out of the learned port only; otherwise it is flooded to every enabled port
        except the one it arrived on.

        :param frame: The Frame being received.
        :param from_network_interface: The SwitchPort that received the frame.
        """
        src_mac = frame.ethernet.src_mac_addr
        dst_mac = frame.ethernet.dst_mac_addr

        # Learn (or refresh) which port the sender is reachable through.
        self._add_mac_table_entry(src_mac, from_network_interface)

        outgoing_port = self.mac_address_table.get(dst_mac)
        if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
            outgoing_port.send_frame(frame)
            return

        # Destination unknown, or a broadcast: flood to all enabled ports except the incoming one.
        for port in self.network_interface.values():
            if port.enabled and port != from_network_interface:
                port.send_frame(frame)

    def disconnect_link_from_port(self, link: Link, port_number: int):
        """
        Disconnect a given link from the specified port number on the switch.

        :param link: The Link object to be disconnected.
        :param port_number: The port number on the switch from where the link should be disconnected.
        :raise NetworkError: When an invalid port number is provided or the link does not match the connection.
        """
        port = self.network_interface.get(port_number)
        if port is None:
            msg = f"Invalid port number {port_number} on the switch"
            _LOGGER.error(msg)
            raise NetworkError(msg)
        if port._connected_link != link:
            msg = f"The link does not match the connection at port number {port_number}"
            _LOGGER.error(msg)
            raise NetworkError(msg)
        port.disconnect_link()