diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index 81476998..b9f921c2 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -16,5 +16,5 @@ Contents :maxdepth: 8 simulation_structure - simulation_components/network/physical_layer + simulation_components/network/base_hardware simulation_components/network/transport_to_data_link_layer diff --git a/docs/source/simulation_components/network/physical_layer.rst b/docs/source/simulation_components/network/base_hardware.rst similarity index 98% rename from docs/source/simulation_components/network/physical_layer.rst rename to docs/source/simulation_components/network/base_hardware.rst index 1e87b72e..c3891a6e 100644 --- a/docs/source/simulation_components/network/physical_layer.rst +++ b/docs/source/simulation_components/network/base_hardware.rst @@ -2,8 +2,8 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -Physical Layer -============== +Base Hardware +============= The physical layer components are models of a ``NIC`` (Network Interface Card) and a ``Link``. These components allow modelling of layer 1 (physical layer) in the OSI model. diff --git a/docs/source/simulation_components/network/transport_to_data_link_layer.rst b/docs/source/simulation_components/network/transport_to_data_link_layer.rst index 8273339c..9332b57c 100644 --- a/docs/source/simulation_components/network/transport_to_data_link_layer.rst +++ b/docs/source/simulation_components/network/transport_to_data_link_layer.rst @@ -34,7 +34,7 @@ specify the priority of IP packets for Quality of Service handling. **ICMPType:** Enumeration of common ICMP (Internet Control Message Protocol) types. It defines various types of ICMP messages used for network troubleshooting and error reporting. -**ICMPHeader:** Models an ICMP header and includes ICMP type, code, identifier, and sequence number. It is used to +**ICMPPacket:** Models an ICMP header and includes ICMP type, code, identifier, and sequence number. It is used to create ICMP packets for network control and error reporting. **IPPacket:** Represents the IP layer of a network frame. It includes source and destination IP addresses, protocol @@ -59,7 +59,7 @@ Data Link Layer (Layer 2) This header is used to identify the physical hardware addresses of devices on a local network. **Frame:** Represents a complete network frame with all layers. It includes an ``EthernetHeader``, an ``IPPacket``, an -optional ``TCPHeader``, ``UDPHeader``, or ``ICMPHeader``, a ``PrimaiteHeader`` and an optional payload. This class +optional ``TCPHeader``, ``UDPHeader``, or ``ICMPPacket``, a ``PrimaiteHeader`` and an optional payload. This class combines all the headers and data to create a complete network frame that can be sent over the network and used in the PrimAITE simulation. diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index c3130116..d684a74b 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -1,6 +1,7 @@ """Core of the PrimAITE Simulator.""" from abc import abstractmethod from typing import Callable, Dict, List +from uuid import uuid4 from pydantic import BaseModel @@ -8,6 +9,14 @@ from pydantic import BaseModel class SimComponent(BaseModel): """Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator.""" + uuid: str + "The component UUID." + + def __init__(self, **kwargs): + if not kwargs.get("uuid"): + kwargs["uuid"] = str(uuid4()) + super().__init__(**kwargs) + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/network/hardware/__init__.py b/src/primaite/simulator/network/hardware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py new file mode 100644 index 00000000..054eb1c6 --- /dev/null +++ b/src/primaite/simulator/network/hardware/base.py @@ -0,0 +1,665 @@ +from __future__ import annotations + +import re +import secrets +from enum import Enum +from ipaddress import IPv4Address, IPv4Network +from typing import Any, Dict, List, Optional, Union + +from primaite import getLogger +from primaite.exceptions import NetworkError +from primaite.simulator.core import SimComponent +from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket +from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame +from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader + +_LOGGER = getLogger(__name__) + + +def generate_mac_address(oui: Optional[str] = None) -> str: + """ + Generate a random MAC Address. + + :Example: + + >>> generate_mac_address() + 'ef:7e:97:c8:a8:ce' + + >>> generate_mac_address(oui='aa:bb:cc') + 'aa:bb:cc:42:ba:41' + + :param oui: The Organizationally Unique Identifier (OUI) portion of the MAC address. It should be a string with + the first 3 bytes (24 bits) in the format "XX:XX:XX". + :raises ValueError: If the 'oui' is not in the correct format (hexadecimal and 6 characters). + """ + random_bytes = [secrets.randbits(8) for _ in range(6)] + + if oui: + oui_pattern = re.compile(r"^([0-9A-Fa-f]{2}[:-]){2}[0-9A-Fa-f]{2}$") + if not oui_pattern.match(oui): + msg = f"Invalid oui. The oui should be in the format xx:xx:xx, where x is a hexadecimal digit, got '{oui}'" + raise ValueError(msg) + oui_bytes = [int(chunk, 16) for chunk in oui.split(":")] + mac = oui_bytes + random_bytes[len(oui_bytes) :] + else: + mac = random_bytes + + return ":".join(f"{b:02x}" for b in mac) + + +class NIC(SimComponent): + """ + Models a Network Interface Card (NIC) in a computer or network device. + + :param ip_address: The IPv4 address assigned to the NIC. + :param subnet_mask: The subnet mask assigned to the NIC. + :param gateway: The default gateway IP address for forwarding network traffic to other networks. + :param mac_address: The MAC address of the NIC. Defaults to a randomly set MAC address. + :param speed: The speed of the NIC in Mbps (default is 100 Mbps). + :param mtu: The Maximum Transmission Unit (MTU) of the NIC in Bytes, representing the largest data packet size it + can handle without fragmentation (default is 1500 B). + :param wake_on_lan: Indicates if the NIC supports Wake-on-LAN functionality. + :param dns_servers: List of IP addresses of DNS servers used for name resolution. + """ + + ip_address: IPv4Address + "The IP address assigned to the NIC for communication on an IP-based network." + subnet_mask: str + "The subnet mask assigned to the NIC." + gateway: IPv4Address + "The default gateway IP address for forwarding network traffic to other networks. Randomly generated upon creation." + mac_address: str + "The MAC address of the NIC. Defaults to a randomly set MAC address." + speed: int = 100 + "The speed of the NIC in Mbps. Default is 100 Mbps." + mtu: int = 1500 + "The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B" + wake_on_lan: bool = False + "Indicates if the NIC supports Wake-on-LAN functionality." + dns_servers: List[IPv4Address] = [] + "List of IP addresses of DNS servers used for name resolution." + connected_node: Optional[Node] = None + "The Node to which the NIC is connected." + connected_link: Optional[Link] = None + "The Link to which the NIC is connected." + enabled: bool = False + "Indicates whether the NIC is enabled." + + def __init__(self, **kwargs): + """ + NIC constructor. + + Performs some type conversion the calls ``super().__init__()``. Then performs some checking on the ip_address + and gateway just to check that it's all been configured correctly. + + :raises ValueError: When the ip_address and gateway are the same. And when the ip_address/subnet mask are a + network address. + """ + if not isinstance(kwargs["ip_address"], IPv4Address): + kwargs["ip_address"] = IPv4Address(kwargs["ip_address"]) + if not isinstance(kwargs["gateway"], IPv4Address): + kwargs["gateway"] = IPv4Address(kwargs["gateway"]) + if "mac_address" not in kwargs: + kwargs["mac_address"] = generate_mac_address() + super().__init__(**kwargs) + + if self.ip_address == self.gateway: + msg = f"NIC ip address {self.ip_address} cannot be the same as the gateway {self.gateway}" + _LOGGER.error(msg) + raise ValueError(msg) + if self.ip_network.network_address == self.ip_address: + msg = ( + f"Failed to set IP address {self.ip_address} and subnet mask {self.subnet_mask} as it is a " + f"network address {self.ip_network.network_address}" + ) + _LOGGER.error(msg) + raise ValueError(msg) + + @property + def ip_network(self) -> IPv4Network: + """ + Return the IPv4Network of the NIC. + + :return: The IPv4Network from the ip_address/subnet mask. + """ + return IPv4Network(f"{self.ip_address}/{self.subnet_mask}", strict=False) + + def enable(self): + """Attempt to enable the NIC.""" + if not self.enabled: + if self.connected_node: + if self.connected_node.hardware_state == HardwareState.ON: + self.enabled = True + _LOGGER.info(f"NIC {self} enabled") + if self.connected_link: + self.connected_link.endpoint_up() + else: + _LOGGER.info(f"NIC {self} cannot be enabled as the endpoint is not turned on") + else: + msg = f"NIC {self} cannot be enabled as it is not connected to a Node" + _LOGGER.error(msg) + raise NetworkError(msg) + + def disable(self): + """Disable the NIC.""" + if self.enabled: + self.enabled = False + _LOGGER.info(f"NIC {self} disabled") + if self.connected_link: + self.connected_link.endpoint_down() + + def connect_link(self, link: Link): + """ + Connect the NIC to a link. + + :param link: The link to which the NIC is connected. + :type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link` + :raise NetworkError: When an attempt to connect a Link is made while the NIC has a connected Link. + """ + if not self.connected_link: + if self.connected_link != link: + _LOGGER.info(f"NIC {self} connected to Link") + # TODO: Inform the Node that a link has been connected + self.connected_link = link + else: + _LOGGER.warning(f"Cannot connect link to NIC ({self.mac_address}) as it is already connected") + else: + msg = f"Cannot connect link to NIC ({self.mac_address}) as it already has a connection" + _LOGGER.error(msg) + raise NetworkError(msg) + + def disconnect_link(self): + """Disconnect the NIC from the connected Link.""" + if self.connected_link.endpoint_a == self: + self.connected_link.endpoint_a = None + if self.connected_link.endpoint_b == self: + self.connected_link.endpoint_b = None + self.connected_link = None + + def add_dns_server(self, ip_address: IPv4Address): + """ + Add a DNS server IP address. + + :param ip_address: The IP address of the DNS server to be added. + :type ip_address: ipaddress.IPv4Address + """ + pass + + def remove_dns_server(self, ip_address: IPv4Address): + """ + Remove a DNS server IP Address. + + :param ip_address: The IP address of the DNS server to be removed. + :type ip_address: ipaddress.IPv4Address + """ + pass + + def send_frame(self, frame: Frame) -> bool: + """ + Send a network frame from the NIC to the connected link. + + :param frame: The network frame to be sent. + :type frame: :class:`~primaite.simulator.network.osi_layers.Frame` + """ + if self.enabled: + self.connected_link.transmit_frame(sender_nic=self, frame=frame) + return True + else: + # Cannot send Frame as the NIC is not enabled + return False + + def receive_frame(self, frame: Frame) -> bool: + """ + Receive a network frame from the connected link if the NIC is enabled. + + The Frame is passed to the Node. + + :param frame: The network frame being received. + :type frame: :class:`~primaite.simulator.network.osi_layers.Frame` + """ + if self.enabled: + self.connected_node.receive_frame(frame=frame, from_nic=self) + return True + else: + return False + + def describe_state(self) -> Dict: + """ + Get the current state of the NIC as a dict. + + :return: A dict containing the current state of the NIC. + """ + pass + + def apply_action(self, action: str): + """ + Apply an action to the NIC. + + :param action: The action to be applied. + :type action: str + """ + pass + + def __str__(self) -> str: + return f"{self.mac_address}/{self.ip_address}" + + +class Link(SimComponent): + """ + Represents a network link between two network interface cards (NICs). + + :param endpoint_a: The first NIC connected to the Link. + :type endpoint_a: NIC + :param endpoint_b: The second NIC connected to the Link. + :type endpoint_b: NIC + :param bandwidth: The bandwidth of the Link in Mbps (default is 100 Mbps). + :type bandwidth: int + """ + + endpoint_a: NIC + "The first NIC connected to the Link." + endpoint_b: NIC + "The second NIC connected to the Link." + bandwidth: int = 100 + "The bandwidth of the Link in Mbps (default is 100 Mbps)." + current_load: float = 0.0 + "The current load on the link in Mbps." + + def __init__(self, **kwargs): + """ + Ensure that endpoint_a and endpoint_b are not the same NIC. + + Connect the link to the NICs after creation. + + :raises ValueError: If endpoint_a and endpoint_b are the same NIC. + """ + if kwargs["endpoint_a"] == kwargs["endpoint_b"]: + msg = "endpoint_a and endpoint_b cannot be the same NIC" + _LOGGER.error(msg) + raise ValueError(msg) + super().__init__(**kwargs) + self.endpoint_a.connect_link(self) + self.endpoint_b.connect_link(self) + if self.up: + _LOGGER.info(f"Link up between {self.endpoint_a} and {self.endpoint_b}") + + def endpoint_up(self): + """Let the Link know and endpoint has been brought up.""" + if self.up: + _LOGGER.info(f"Link up between {self.endpoint_a} and {self.endpoint_b}") + + def endpoint_down(self): + """Let the Link know and endpoint has been brought down.""" + if not self.up: + self.current_load = 0.0 + _LOGGER.info(f"Link down between {self.endpoint_a} and {self.endpoint_b}") + + @property + def up(self) -> bool: + """ + Informs whether the link is up. + + This is based upon both NIC endpoints being enabled. + """ + return self.endpoint_a.enabled and self.endpoint_b.enabled + + def _can_transmit(self, frame: Frame) -> bool: + if self.up: + frame_size_Mbits = frame.size_Mbits # noqa - Leaving it as Mbits as this is how they're expressed + return self.current_load + frame_size_Mbits <= self.bandwidth + return False + + def transmit_frame(self, sender_nic: NIC, frame: Frame) -> bool: + """ + Send a network frame from one NIC to another connected NIC. + + :param sender_nic: The NIC sending the frame. + :param frame: The network frame to be sent. + :return: True if the Frame can be sent, otherwise False. + """ + receiver_nic = self.endpoint_a + if receiver_nic == sender_nic: + receiver_nic = self.endpoint_b + frame_size = frame.size + sent = receiver_nic.receive_frame(frame) + if sent: + # Frame transmitted successfully + # Load the frame size on the link + self.current_load += frame_size + return True + # Received NIC disabled, reply + + return False + + def reset_component_for_episode(self): + """ + Link reset function. + + Reset: + - returns the link current_load to 0. + """ + self.current_load = 0 + + def describe_state(self) -> Dict: + """ + Get the current state of the Libk as a dict. + + :return: A dict containing the current state of the Link. + """ + pass + + def apply_action(self, action: str): + """ + Apply an action to the Link. + + :param action: The action to be applied. + :type action: str + """ + pass + + +class HardwareState(Enum): + """Node hardware state enumeration.""" + + ON = 1 + OFF = 2 + RESETTING = 3 + SHUTTING_DOWN = 4 + BOOTING = 5 + + +class Node(SimComponent): + """ + A basic Node class. + + :param hostname: The node hostname on the network. + :param hardware_state: The hardware state of the node. + """ + + hostname: str + "The node hostname on the network." + hardware_state: HardwareState = HardwareState.OFF + "The hardware state of the node." + nics: Dict[str, NIC] = {} + "The NICs on the node." + + accounts: Dict = {} + "All accounts on the node." + applications: Dict = {} + "All applications on the node." + services: Dict = {} + "All services on the node." + processes: Dict = {} + "All processes on the node." + file_system: Any = None + "The nodes file system." + arp_cache: Dict[IPv4Address, ARPEntry] = {} + "The ARP cache." + + revealed_to_red: bool = False + "Informs whether the node has been revealed to a red agent." + + def turn_on(self): + """Turn on the Node.""" + if self.hardware_state == HardwareState.OFF: + self.hardware_state = HardwareState.ON + _LOGGER.info(f"Node {self.hostname} turned on") + for nic in self.nics.values(): + nic.enable() + + def turn_off(self): + """Turn off the Node.""" + if self.hardware_state == HardwareState.ON: + for nic in self.nics.values(): + nic.disable() + self.hardware_state = HardwareState.OFF + _LOGGER.info(f"Node {self.hostname} turned off") + + def connect_nic(self, nic: NIC): + """ + Connect a NIC. + + :param nic: The NIC to connect. + :raise NetworkError: If the NIC is already connected. + """ + if nic.uuid not in self.nics: + self.nics[nic.uuid] = nic + nic.connected_node = self + _LOGGER.debug(f"Node {self.hostname} connected NIC {nic}") + if self.hardware_state == HardwareState.ON: + nic.enable() + else: + msg = f"Cannot connect NIC {nic} to Node {self.hostname} as it is already connected" + _LOGGER.error(msg) + raise NetworkError(msg) + + def disconnect_nic(self, nic: Union[NIC, str]): + """ + Disconnect a NIC. + + :param nic: The NIC to Disconnect. + :raise NetworkError: If the NIC is not connected. + """ + if isinstance(nic, str): + nic = self.nics.get(nic) + if nic or nic.uuid in self.nics: + self.nics.pop(nic.uuid) + nic.disable() + _LOGGER.debug(f"Node {self.hostname} disconnected NIC {nic}") + else: + msg = f"Cannot disconnect NIC {nic} from Node {self.hostname} as it is not connected" + _LOGGER.error(msg) + raise NetworkError(msg) + + def _add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC): + """ + Add an ARP entry to the cache. + + :param ip_address: The IP address to be added to the cache. + :param mac_address: The MAC address associated with the IP address. + :param nic: The NIC through which the NIC with the IP address is reachable. + """ + _LOGGER.info(f"Node {self.hostname} Adding ARP cache entry for {mac_address}/{ip_address} via NIC {nic}") + arp_entry = ARPEntry(mac_address=mac_address, nic_uuid=nic.uuid) + self.arp_cache[ip_address] = arp_entry + + def _remove_arp_cache_entry(self, ip_address: IPv4Address): + """ + Remove an ARP entry from the cache. + + :param ip_address: The IP address to be removed from the cache. + """ + if ip_address in self.arp_cache: + del self.arp_cache[ip_address] + + def _get_arp_cache_mac_address(self, ip_address: IPv4Address) -> Optional[str]: + """ + Get the MAC address associated with an IP address. + + :param ip_address: The IP address to look up in the cache. + :return: The MAC address associated with the IP address, or None if not found. + """ + arp_entry = self.arp_cache.get(ip_address) + if arp_entry: + return arp_entry.mac_address + + def _get_arp_cache_nic(self, ip_address: IPv4Address) -> Optional[NIC]: + """ + Get the NIC associated with an IP address. + + :param ip_address: The IP address to look up in the cache. + :return: The NIC associated with the IP address, or None if not found. + """ + arp_entry = self.arp_cache.get(ip_address) + if arp_entry: + return self.nics[arp_entry.nic_uuid] + + def _clear_arp_cache(self): + """Clear the entire ARP cache.""" + self.arp_cache.clear() + + def _send_arp_request(self, target_ip_address: Union[IPv4Address, str]): + """Perform a standard ARP request for a given target IP address.""" + for nic in self.nics.values(): + if nic.enabled: + _LOGGER.info(f"Node {self.hostname} sending ARP request from NIC {nic} for ip {target_ip_address}") + tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + + # Network Layer + ip_packet = IPPacket( + src_ip=nic.ip_address, + dst_ip=target_ip_address, + ) + # Data Link Layer + ethernet_header = EthernetHeader(src_mac_addr=nic.mac_address, dst_mac_addr="ff:ff:ff:ff:ff:ff") + arp_packet = ARPPacket( + sender_ip=nic.ip_address, sender_mac_addr=nic.mac_address, target_ip=target_ip_address + ) + frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet) + nic.send_frame(frame) + + def process_arp_packet(self, from_nic: NIC, arp_packet: ARPPacket): + """ + Process an ARP packet. + + # TODO: This will become a service that sits on the Node. + + :param from_nic: The NIC the arp packet was received at. + :param arp_packet:The ARP packet to process. + """ + if arp_packet.request: + _LOGGER.info( + f"Node {self.hostname} received ARP request from {arp_packet.sender_mac_addr}/{arp_packet.sender_ip}" + ) + self._add_arp_cache_entry( + ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic + ) + arp_packet = arp_packet.generate_reply(from_nic.mac_address) + _LOGGER.info( + f"Node {self.hostname} sending ARP reply from {arp_packet.sender_mac_addr}/{arp_packet.sender_ip} " + f"to {arp_packet.target_ip}/{arp_packet.target_mac_addr} " + ) + + tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + + # Network Layer + ip_packet = IPPacket( + src_ip=arp_packet.sender_ip, + dst_ip=arp_packet.target_ip, + ) + # Data Link Layer + ethernet_header = EthernetHeader( + src_mac_addr=arp_packet.sender_mac_addr, dst_mac_addr=arp_packet.target_mac_addr + ) + frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet) + self.send_frame(frame) + else: + _LOGGER.info( + f"Node {self.hostname} received ARP response for {arp_packet.sender_ip} " + f"from {arp_packet.sender_mac_addr} via NIC {from_nic}" + ) + self._add_arp_cache_entry( + ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic + ) + + def process_icmp(self, frame: Frame): + """ + Process an ICMP packet. + + # TODO: This will become a service that sits on the Node. + + :param frame: The Frame containing the icmp packet to process. + """ + if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST: + _LOGGER.info(f"Node {self.hostname} received echo request from {frame.ip.src_ip}") + target_mac_address = self._get_arp_cache_mac_address(frame.ip.src_ip) + src_nic = self._get_arp_cache_nic(frame.ip.src_ip) + tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + + # Network Layer + ip_packet = IPPacket(src_ip=src_nic.ip_address, dst_ip=frame.ip.src_ip, protocol=IPProtocol.ICMP) + # Data Link Layer + ethernet_header = EthernetHeader(src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address) + icmp_reply_packet = ICMPPacket( + icmp_type=ICMPType.ECHO_REPLY, + icmp_code=0, + identifier=frame.icmp.identifier, + sequence=frame.icmp.sequence + 1, + ) + frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet) + src_nic.send_frame(frame) + elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY: + _LOGGER.info(f"Node {self.hostname} received echo reply from {frame.ip.src_ip}") + if frame.icmp.sequence <= 6: # 3 pings + self._ping(frame.ip.src_ip, sequence=frame.icmp.sequence, identifier=frame.icmp.identifier) + + def _ping(self, target_ip_address: IPv4Address, sequence: int = 0, identifier: Optional[int] = None): + nic = self._get_arp_cache_nic(target_ip_address) + if nic: + sequence += 1 + target_mac_address = self._get_arp_cache_mac_address(target_ip_address) + src_nic = self._get_arp_cache_nic(target_ip_address) + tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + + # Network Layer + ip_packet = IPPacket( + src_ip=nic.ip_address, + dst_ip=target_ip_address, + protocol=IPProtocol.ICMP, + ) + # Data Link Layer + ethernet_header = EthernetHeader(src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address) + icmp_packet = ICMPPacket(identifier=identifier, sequence=sequence) + frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_packet) + nic.send_frame(frame) + else: + _LOGGER.info(f"Node {self.hostname} no entry in ARP cache for {target_ip_address}") + self._send_arp_request(target_ip_address) + self._ping(target_ip_address=target_ip_address) + + def ping(self, target_ip_address: Union[IPv4Address, str]) -> bool: + """ + Ping an IP address. + + Performs a standard ICMP echo request/response four times. + + :param target_ip_address: The target IP address to ping. + :return: True if successful, otherwise False. + """ + if not isinstance(target_ip_address, IPv4Address): + target_ip_address = IPv4Address(target_ip_address) + if self.hardware_state == HardwareState.ON: + _LOGGER.info(f"Node {self.hostname} attempting to ping {target_ip_address}") + self._ping(target_ip_address) + return True + return False + + def send_frame(self, frame: Frame): + """ + Send a Frame from the Node to the connected NIC. + + :param frame: The Frame to be sent. + """ + nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip) + nic.send_frame(frame) + + def receive_frame(self, frame: Frame, from_nic: NIC): + """ + Receive a Frame from the connected NIC. + + The Frame is passed to up to the SessionManager. + + :param frame: The Frame being received. + """ + if frame.ip.protocol == IPProtocol.TCP: + if frame.tcp.src_port == Port.ARP: + self.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) + elif frame.ip.protocol == IPProtocol.UDP: + pass + elif frame.ip.protocol == IPProtocol.ICMP: + self.process_icmp(frame=frame) + + def describe_state(self) -> Dict: + """Describe the state of a Node.""" + pass diff --git a/src/primaite/simulator/network/nodes/__init__.py b/src/primaite/simulator/network/nodes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/simulator/network/protocols/__init__.py b/src/primaite/simulator/network/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/simulator/network/protocols/arp.py b/src/primaite/simulator/network/protocols/arp.py new file mode 100644 index 00000000..bae14d28 --- /dev/null +++ b/src/primaite/simulator/network/protocols/arp.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from ipaddress import IPv4Address +from typing import Optional + +from pydantic import BaseModel + + +class ARPEntry(BaseModel): + """ + Represents an entry in the ARP cache. + + :param mac_address: The MAC address associated with the IP address. + :param nic: The NIC through which the NIC with the IP address is reachable. + """ + + mac_address: str + nic_uuid: str + + +class ARPPacket(BaseModel): + """ + Represents the ARP layer of a network frame. + + :param request: ARP operation. True if a request, False if a reply. + :param sender_mac_addr: Sender MAC address. + :param sender_ip: Sender IP address. + :param target_mac_addr: Target MAC address. + :param target_ip: Target IP address. + + :Example: + + >>> arp_request = ARPPacket( + ... sender_mac_addr="aa:bb:cc:dd:ee:ff", + ... sender_ip=IPv4Address("192.168.0.1"), + ... target_ip=IPv4Address("192.168.0.2") + ... ) + >>> arp_response = ARPPacket( + ... sender_mac_addr="aa:bb:cc:dd:ee:ff", + ... sender_ip=IPv4Address("192.168.0.1"), + ... target_ip=IPv4Address("192.168.0.2") + ... ) + """ + + request: bool = True + "ARP operation. True if a request, False if a reply." + sender_mac_addr: str + "Sender MAC address." + sender_ip: IPv4Address + "Sender IP address." + target_mac_addr: Optional[str] = None + "Target MAC address." + target_ip: IPv4Address + "Target IP address." + + def generate_reply(self, mac_address: str) -> ARPPacket: + """ + Generate a new ARPPacket to be sent as a response with a given mac address. + + :param mac_address: The mac_address that was being sought after from the original target IP address. + :return: A new instance of ARPPacket. + """ + return ARPPacket( + request=False, + sender_ip=self.target_ip, + sender_mac_addr=mac_address, + target_ip=self.sender_ip, + target_mac_addr=self.sender_mac_addr, + ) diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index b9d969bd..bc7e2453 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -3,9 +3,11 @@ from typing import Any, Optional from pydantic import BaseModel from primaite import getLogger -from primaite.simulator.network.transmission.network_layer import ICMPHeader, IPPacket, IPProtocol +from primaite.simulator.network.protocols.arp import ARPPacket +from primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.simulator.network.utils import convert_bytes_to_megabits _LOGGER = getLogger(__name__) @@ -74,9 +76,11 @@ class Frame(BaseModel): _LOGGER.error(msg) raise ValueError(msg) if kwargs["ip"].protocol == IPProtocol.ICMP and not kwargs.get("icmp"): - msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPHeader" + msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) + kwargs["primaite"] = PrimaiteHeader() + super().__init__(**kwargs) ethernet: EthernetHeader @@ -87,14 +91,21 @@ class Frame(BaseModel): "TCP header." udp: Optional[UDPHeader] = None "UDP header." - icmp: Optional[ICMPHeader] = None + icmp: Optional[ICMPPacket] = None "ICMP header." - primaite: PrimaiteHeader = PrimaiteHeader() + arp: Optional[ARPPacket] = None + "ARP packet." + primaite: PrimaiteHeader "PrimAITE header." payload: Optional[Any] = None "Raw data payload." @property - def size(self) -> int: - """The size in Bytes.""" - return len(self.model_dump_json().encode("utf-8")) + def size(self) -> float: # noqa - Keep it as MBits as this is how they're expressed + """The size of the Frame in Bytes.""" + return float(len(self.model_dump_json().encode("utf-8"))) + + @property + def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed + """The daa transfer size of the Frame in MBits.""" + return convert_bytes_to_megabits(self.size) diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index 69b682cc..afd1ecef 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -120,18 +120,23 @@ def get_icmp_type_code_description(icmp_type: ICMPType, icmp_code: int) -> Union return icmp_code_descriptions[icmp_type].get(icmp_code) -class ICMPHeader(BaseModel): - """Models an ICMP Header.""" +class ICMPPacket(BaseModel): + """Models an ICMP Packet.""" icmp_type: ICMPType = ICMPType.ECHO_REQUEST "ICMP Type." icmp_code: int = 0 "ICMP Code." - identifier: str = secrets.randbits(16) + identifier: int "ICMP identifier (16 bits randomly generated)." - sequence: int = 1 + sequence: int = 0 "ICMP message sequence number." + def __init__(self, **kwargs): + if not kwargs.get("identifier"): + kwargs["identifier"] = secrets.randbits(16) + super().__init__(**kwargs) + @field_validator("icmp_code") # noqa @classmethod def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int: diff --git a/src/primaite/simulator/network/transmission/physical_layer.py b/src/primaite/simulator/network/transmission/physical_layer.py deleted file mode 100644 index ee2297b6..00000000 --- a/src/primaite/simulator/network/transmission/physical_layer.py +++ /dev/null @@ -1,277 +0,0 @@ -from __future__ import annotations - -import re -import secrets -from ipaddress import IPv4Address, IPv4Network -from typing import Dict, List, Optional - -from primaite import getLogger -from primaite.exceptions import NetworkError -from primaite.simulator.core import SimComponent -from primaite.simulator.network.transmission.data_link_layer import Frame - -_LOGGER = getLogger(__name__) - - -def generate_mac_address(oui: Optional[str] = None) -> str: - """ - Generate a random MAC Address. - - :Example: - - >>> generate_mac_address() - 'ef:7e:97:c8:a8:ce' - - >>> generate_mac_address(oui='aa:bb:cc') - 'aa:bb:cc:42:ba:41' - - :param oui: The Organizationally Unique Identifier (OUI) portion of the MAC address. It should be a string with - the first 3 bytes (24 bits) in the format "XX:XX:XX". - :raises ValueError: If the 'oui' is not in the correct format (hexadecimal and 6 characters). - """ - random_bytes = [secrets.randbits(8) for _ in range(6)] - - if oui: - oui_pattern = re.compile(r"^([0-9A-Fa-f]{2}[:-]){2}[0-9A-Fa-f]{2}$") - if not oui_pattern.match(oui): - msg = f"Invalid oui. The oui should be in the format xx:xx:xx, where x is a hexadecimal digit, got '{oui}'" - raise ValueError(msg) - oui_bytes = [int(chunk, 16) for chunk in oui.split(":")] - mac = oui_bytes + random_bytes[len(oui_bytes) :] - else: - mac = random_bytes - - return ":".join(f"{b:02x}" for b in mac) - - -class NIC(SimComponent): - """ - Models a Network Interface Card (NIC) in a computer or network device. - - :param ip_address: The IPv4 address assigned to the NIC. - :param subnet_mask: The subnet mask assigned to the NIC. - :param gateway: The default gateway IP address for forwarding network traffic to other networks. - :param mac_address: The MAC address of the NIC. Defaults to a randomly set MAC address. - :param speed: The speed of the NIC in Mbps (default is 100 Mbps). - :param mtu: The Maximum Transmission Unit (MTU) of the NIC in Bytes, representing the largest data packet size it - can handle without fragmentation (default is 1500 B). - :param wake_on_lan: Indicates if the NIC supports Wake-on-LAN functionality. - :param dns_servers: List of IP addresses of DNS servers used for name resolution. - """ - - ip_address: IPv4Address - "The IP address assigned to the NIC for communication on an IP-based network." - subnet_mask: str - "The subnet mask assigned to the NIC." - gateway: IPv4Address - "The default gateway IP address for forwarding network traffic to other networks. Randomly generated upon creation." - mac_address: str = generate_mac_address() - "The MAC address of the NIC. Defaults to a randomly set MAC address." - speed: int = 100 - "The speed of the NIC in Mbps. Default is 100 Mbps." - mtu: int = 1500 - "The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B" - wake_on_lan: bool = False - "Indicates if the NIC supports Wake-on-LAN functionality." - dns_servers: List[IPv4Address] = [] - "List of IP addresses of DNS servers used for name resolution." - connected_link: Optional[Link] = None - "The Link to which the NIC is connected." - enabled: bool = False - "Indicates whether the NIC is enabled." - - def __init__(self, **kwargs): - """ - NIC constructor. - - Performs some type conversion the calls ``super().__init__()``. Then performs some checking on the ip_address - and gateway just to check that it's all been configured correctly. - - :raises ValueError: When the ip_address and gateway are the same. And when the ip_address/subnet mask are a - network address. - """ - if not isinstance(kwargs["ip_address"], IPv4Address): - kwargs["ip_address"] = IPv4Address(kwargs["ip_address"]) - if not isinstance(kwargs["gateway"], IPv4Address): - kwargs["gateway"] = IPv4Address(kwargs["gateway"]) - super().__init__(**kwargs) - - if self.ip_address == self.gateway: - msg = f"NIC ip address {self.ip_address} cannot be the same as the gateway {self.gateway}" - _LOGGER.error(msg) - raise ValueError(msg) - if self.ip_network.network_address == self.ip_address: - msg = ( - f"Failed to set IP address {self.ip_address} and subnet mask {self.subnet_mask} as it is a " - f"network address {self.ip_network.network_address}" - ) - _LOGGER.error(msg) - raise ValueError(msg) - - @property - def ip_network(self) -> IPv4Network: - """ - Return the IPv4Network of the NIC. - - :return: The IPv4Network from the ip_address/subnet mask. - """ - return IPv4Network(f"{self.ip_address}/{self.subnet_mask}", strict=False) - - def connect_link(self, link: Link): - """ - Connect the NIC to a link. - - :param link: The link to which the NIC is connected. - :type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link` - :raise NetworkError: When an attempt to connect a Link is made while the NIC has a connected Link. - """ - if not self.connected_link: - if self.connected_link != link: - # TODO: Inform the Node that a link has been connected - self.connected_link = link - else: - _LOGGER.warning(f"Cannot connect link to NIC ({self.mac_address}) as it is already connected") - else: - msg = f"Cannot connect link to NIC ({self.mac_address}) as it already has a connection" - _LOGGER.error(msg) - raise NetworkError(msg) - - def disconnect_link(self): - """Disconnect the NIC from the connected Link.""" - if self.connected_link.endpoint_a == self: - self.connected_link.endpoint_a = None - if self.connected_link.endpoint_b == self: - self.connected_link.endpoint_b = None - self.connected_link = None - - def add_dns_server(self, ip_address: IPv4Address): - """ - Add a DNS server IP address. - - :param ip_address: The IP address of the DNS server to be added. - :type ip_address: ipaddress.IPv4Address - """ - pass - - def remove_dns_server(self, ip_address: IPv4Address): - """ - Remove a DNS server IP Address. - - :param ip_address: The IP address of the DNS server to be removed. - :type ip_address: ipaddress.IPv4Address - """ - pass - - def send_frame(self, frame: Frame): - """ - Send a network frame from the NIC to the connected link. - - :param frame: The network frame to be sent. - :type frame: :class:`~primaite.simulator.network.osi_layers.Frame` - """ - pass - - def receive_frame(self, frame: Frame): - """ - Receive a network frame from the connected link. - - The Frame is passed to the Node. - - :param frame: The network frame being received. - :type frame: :class:`~primaite.simulator.network.osi_layers.Frame` - """ - pass - - def describe_state(self) -> Dict: - """ - Get the current state of the NIC as a dict. - - :return: A dict containing the current state of the NIC. - """ - pass - - def apply_action(self, action: str): - """ - Apply an action to the NIC. - - :param action: The action to be applied. - :type action: str - """ - pass - - -class Link(SimComponent): - """ - Represents a network link between two network interface cards (NICs). - - :param endpoint_a: The first NIC connected to the Link. - :type endpoint_a: NIC - :param endpoint_b: The second NIC connected to the Link. - :type endpoint_b: NIC - :param bandwidth: The bandwidth of the Link in Mbps (default is 100 Mbps). - :type bandwidth: int - """ - - endpoint_a: NIC - "The first NIC connected to the Link." - endpoint_b: NIC - "The second NIC connected to the Link." - bandwidth: int = 100 - "The bandwidth of the Link in Mbps (default is 100 Mbps)." - current_load: int = 0 - "The current load on the link in Mbps." - - def __init__(self, **kwargs): - """ - Ensure that endpoint_a and endpoint_b are not the same NIC. - - Connect the link to the NICs after creation. - - :raises ValueError: If endpoint_a and endpoint_b are the same NIC. - """ - if kwargs["endpoint_a"] == kwargs["endpoint_b"]: - msg = "endpoint_a and endpoint_b cannot be the same NIC" - _LOGGER.error(msg) - raise ValueError(msg) - super().__init__(**kwargs) - self.endpoint_a.connect_link(self) - self.endpoint_b.connect_link(self) - - def send_frame(self, sender_nic: NIC, frame: Frame): - """ - Send a network frame from one NIC to another connected NIC. - - :param sender_nic: The NIC sending the frame. - :type sender_nic: NIC - :param frame: The network frame to be sent. - :type frame: Frame - """ - pass - - def receive_frame(self, sender_nic: NIC, frame: Frame): - """ - Receive a network frame from a connected NIC. - - :param sender_nic: The NIC sending the frame. - :type sender_nic: NIC - :param frame: The network frame being received. - :type frame: Frame - """ - pass - - def describe_state(self) -> Dict: - """ - Get the current state of the Libk as a dict. - - :return: A dict containing the current state of the Link. - """ - pass - - def apply_action(self, action: str): - """ - Apply an action to the Link. - - :param action: The action to be applied. - :type action: str - """ - pass diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index c8e6b89d..b95b4a74 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -33,6 +33,8 @@ class Port(Enum): "Simple Network Management Protocol (SNMP) - Used for network device management." SNMP_TRAP = 162 "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." + ARP = 219 + "Address resolution Protocol - Used to connect a MAC address to an IP address." LDAP = 389 "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." HTTPS = 443 @@ -114,6 +116,6 @@ class TCPHeader(BaseModel): ... ) """ - src_port: int - dst_port: int + src_port: Port + dst_port: Port flags: List[TCPFlags] = [TCPFlags.SYN] diff --git a/src/primaite/simulator/network/utils.py b/src/primaite/simulator/network/utils.py new file mode 100644 index 00000000..496f5e13 --- /dev/null +++ b/src/primaite/simulator/network/utils.py @@ -0,0 +1,27 @@ +from typing import Union + + +def convert_bytes_to_megabits(B: Union[int, float]) -> float: # noqa - Keep it as B as this is how Bytes are expressed + """ + Convert Bytes (file size) to Megabits (data transfer). + + :param B: The file size in Bytes. + :return: File bits to transfer in Megabits. + """ + if isinstance(B, int): + B = float(B) + bits = B * 8.0 + return bits / 1024.0**2.0 + + +def convert_megabits_to_bytes(Mbits: Union[int, float]) -> float: # noqa - The same for Mbits + """ + Convert Megabits (data transfer) to Bytes (file size). + + :param Mbits bits to transfer in Megabits. + :return: The file size in Bytes. + """ + if isinstance(Mbits, int): + Mbits = float(Mbits) + bits = Mbits * 1024.0**2.0 + return bits / 8 diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py new file mode 100644 index 00000000..32abd0ef --- /dev/null +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -0,0 +1,25 @@ +from primaite.simulator.network.hardware.base import Link, NIC, Node + + +def test_node_to_node_ping(): + node_a = Node(hostname="node_a") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_a.connect_nic(nic_a) + node_a.turn_on() + + node_b = Node(hostname="node_b") + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_b.connect_nic(nic_b) + node_b.turn_on() + + Link(endpoint_a=nic_a, endpoint_b=nic_b) + + assert node_a.ping("192.168.0.11") + + node_a.turn_off() + + assert not node_a.ping("192.168.0.11") + + node_a.turn_on() + + assert node_a.ping("192.168.0.11") diff --git a/tests/integration_tests/network/test_link_connection.py b/tests/integration_tests/network/test_link_connection.py new file mode 100644 index 00000000..50abed77 --- /dev/null +++ b/tests/integration_tests/network/test_link_connection.py @@ -0,0 +1,21 @@ +from primaite.simulator.network.hardware.base import Link, NIC, Node + + +def test_link_up(): + """Tests Nodes, NICs, and Links can all be connected and be in an enabled/up state.""" + node_a = Node(hostname="node_a") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_a.connect_nic(nic_a) + node_a.turn_on() + assert nic_a.enabled + + node_b = Node(hostname="node_b") + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_b.connect_nic(nic_b) + node_b.turn_on() + + assert nic_b.enabled + + link = Link(endpoint_a=nic_a, endpoint_b=nic_b) + + assert link.up diff --git a/tests/integration_tests/network/test_nic_link_connection.py b/tests/integration_tests/network/test_nic_link_connection.py index 6bca3c0a..52a0c735 100644 --- a/tests/integration_tests/network/test_nic_link_connection.py +++ b/tests/integration_tests/network/test_nic_link_connection.py @@ -1,6 +1,6 @@ import pytest -from primaite.simulator.network.transmission.physical_layer import Link, NIC +from primaite.simulator.network.hardware.base import Link, NIC def test_link_fails_with_same_nic(): diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/__init__.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_physical_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py similarity index 95% rename from tests/unit_tests/_primaite/_simulator/_network/_transmission/test_physical_layer.py rename to tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py index 5a33e723..dc508508 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_physical_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py @@ -3,7 +3,7 @@ from ipaddress import IPv4Address import pytest -from primaite.simulator.network.transmission.physical_layer import generate_mac_address, NIC +from primaite.simulator.network.hardware.base import generate_mac_address, NIC def test_mac_address_generation(): diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node.py new file mode 100644 index 00000000..0e5fb4c7 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node.py @@ -0,0 +1,10 @@ +import re +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.base import Node + + +def test_node_creation(): + node = Node(hostname="host_1") diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 83b215ca..8a78d1bc 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -1,7 +1,7 @@ import pytest from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import ICMPHeader, IPPacket, IPProtocol, Precedence +from primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol, Precedence from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus from primaite.simulator.network.transmission.transport_layer import Port, TCPFlags, TCPHeader, UDPHeader @@ -76,7 +76,7 @@ def test_icmp_frame_creation(): frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.ICMP), - icmp=ICMPHeader(), + icmp=ICMPPacket(), ) assert frame diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py index 584ff25d..a7189452 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py @@ -1,24 +1,24 @@ import pytest -from primaite.simulator.network.transmission.network_layer import ICMPHeader, ICMPType +from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType def test_icmp_minimal_header_creation(): - """Checks the minimal ICMPHeader (ping 1 request) creation using default values.""" - ping = ICMPHeader() + """Checks the minimal ICMPPacket (ping 1 request) creation using default values.""" + ping = ICMPPacket() assert ping.icmp_type == ICMPType.ECHO_REQUEST assert ping.icmp_code == 0 assert ping.identifier - assert ping.sequence == 1 + assert ping.sequence == 0 def test_valid_icmp_type_code_pairing(): - """Tests ICMPHeader creation with valid type and code pairing.""" - assert ICMPHeader(icmp_type=ICMPType.DESTINATION_UNREACHABLE, icmp_code=6) + """Tests ICMPPacket creation with valid type and code pairing.""" + assert ICMPPacket(icmp_type=ICMPType.DESTINATION_UNREACHABLE, icmp_code=6) def test_invalid_icmp_type_code_pairing(): - """Tests ICMPHeader creation fails with invalid type and code pairing.""" + """Tests ICMPPacket creation fails with invalid type and code pairing.""" with pytest.raises(ValueError): - assert ICMPHeader(icmp_type=ICMPType.DESTINATION_UNREACHABLE, icmp_code=16) + assert ICMPPacket(icmp_type=ICMPType.DESTINATION_UNREACHABLE, icmp_code=16) diff --git a/tests/unit_tests/_primaite/_simulator/test_core.py b/tests/unit_tests/_primaite/_simulator/test_core.py index de0732f9..9f4b5fd9 100644 --- a/tests/unit_tests/_primaite/_simulator/test_core.py +++ b/tests/unit_tests/_primaite/_simulator/test_core.py @@ -1,4 +1,5 @@ from typing import Callable, Dict, List, Literal, Tuple +from uuid import uuid4 import pytest from pydantic import ValidationError @@ -35,15 +36,17 @@ class TestIsolatedSimComponent: """Validate that our added functionality does not interfere with pydantic.""" class TestComponent(SimComponent): + uuid: str name: str size: Tuple[float, float] def describe_state(self) -> Dict: return {} - comp = TestComponent(name="computer", size=(5, 10)) + uuid = str(uuid4()) + comp = TestComponent(uuid=uuid, name="computer", size=(5, 10)) dump = comp.model_dump() - assert dump == {"name": "computer", "size": (5, 10)} + assert dump == {"uuid": uuid, "name": "computer", "size": (5, 10)} def test_apply_action(self): """Validate that we can override apply_action behaviour and it updates the state of the component."""