diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index c64b9b67..921ebbcd 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -77,12 +77,10 @@ class NIC(SimComponent): ip_address: IPv4Address "The IP address assigned to the NIC for communication on an IP-based network." - subnet_mask: str + subnet_mask: IPv4Address "The subnet mask assigned to the NIC." - gateway: IPv4Address - "The default gateway IP address for forwarding network traffic to other networks. Randomly generated upon creation." mac_address: str - "The MAC address of the NIC. Defaults to a randomly set MAC address." + "The MAC address of the NIC. Defaults to a randomly set MAC address. Randomly generated upon creation." speed: int = 100 "The speed of the NIC in Mbps. Default is 100 Mbps." mtu: int = 1500 @@ -111,16 +109,10 @@ class NIC(SimComponent): """ if not isinstance(kwargs["ip_address"], IPv4Address): kwargs["ip_address"] = IPv4Address(kwargs["ip_address"]) - if not isinstance(kwargs["gateway"], IPv4Address): - kwargs["gateway"] = IPv4Address(kwargs["gateway"]) if "mac_address" not in kwargs: kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) - if self.ip_address == self.gateway: - msg = f"NIC ip address {self.ip_address} cannot be the same as the gateway {self.gateway}" - _LOGGER.error(msg) - raise ValueError(msg) if self.ip_network.network_address == self.ip_address: msg = ( f"Failed to set IP address {self.ip_address} and subnet mask {self.subnet_mask} as it is a " @@ -274,6 +266,9 @@ class NIC(SimComponent): if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff": self.connected_node.receive_frame(frame=frame, from_nic=self) return True + else: + self.connected_node.sys_log.info("Dropping frame not for me") + print(frame) return False def __str__(self) -> str: @@ -567,7 +562,21 @@ class ARPCache: self.arp: Dict[IPv4Address, ARPEntry] = {} self.nics: Dict[str, "NIC"] = {} - def _add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC): + def show(self): + """Prints a table of ARC Cache.""" + table = PrettyTable(["IP Address", "MAC Address", "Via"]) + table.title = f"{self.sys_log.hostname} ARP Cache" + for ip, arp in self.arp.items(): + table.add_row( + [ + str(ip), + arp.mac_address, + self.nics[arp.nic_uuid].mac_address, + ] + ) + print(table) + + def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False): """ Add an ARP entry to the cache. @@ -575,9 +584,14 @@ class ARPCache: :param mac_address: The MAC address associated with the IP address. :param nic: The NIC through which the NIC with the IP address is reachable. """ - self.sys_log.info(f"Adding ARP cache entry for {mac_address}/{ip_address} via NIC {nic}") - arp_entry = ARPEntry(mac_address=mac_address, nic_uuid=nic.uuid) - self.arp[ip_address] = arp_entry + for _nic in self.nics.values(): + if _nic.ip_address == ip_address: + return + if override or not self.arp.get(ip_address): + self.sys_log.info(f"Adding ARP cache entry for {mac_address}/{ip_address} via NIC {nic}") + arp_entry = ARPEntry(mac_address=mac_address, nic_uuid=nic.uuid) + + self.arp[ip_address] = arp_entry def _remove_arp_cache_entry(self, ip_address: IPv4Address): """ @@ -607,6 +621,7 @@ class ARPCache: :return: The NIC associated with the IP address, or None if not found. """ arp_entry = self.arp.get(ip_address) + if arp_entry: return self.nics[arp_entry.nic_uuid] @@ -641,6 +656,29 @@ class ARPCache: frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet) nic.send_frame(frame) + def send_arp_reply(self, arp_reply: ARPPacket, from_nic: NIC): + """ + Send an ARP reply back through the NIC it came from. + + :param arp_reply: The ARP reply to send. + :param from_nic: The NIC to send the ARP reply from. + """ + self.sys_log.info( + f"Sending ARP reply from {arp_reply.sender_mac_addr}/{arp_reply.sender_ip} " + f"to {arp_reply.target_ip}/{arp_reply.target_mac_addr} " + ) + tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + + ip_packet = IPPacket( + src_ip=arp_reply.sender_ip, + dst_ip=arp_reply.target_ip, + ) + + ethernet_header = EthernetHeader(src_mac_addr=arp_reply.sender_mac_addr, dst_mac_addr=arp_reply.target_mac_addr) + + frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_reply) + from_nic.send_frame(frame) + def process_arp_packet(self, from_nic: NIC, arp_packet: ARPPacket): """ Process a received ARP packet, handling both ARP requests and responses. @@ -656,7 +694,7 @@ class ARPCache: self.sys_log.info( f"Received ARP response for {arp_packet.sender_ip} from {arp_packet.sender_mac_addr} via NIC {from_nic}" ) - self._add_arp_cache_entry( + self.add_arp_cache_entry( ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic ) return @@ -673,26 +711,13 @@ class ARPCache: return # Matched ARP request - self._add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic) + self.add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic) arp_packet = arp_packet.generate_reply(from_nic.mac_address) - self.sys_log.info( - f"Sending ARP reply from {arp_packet.sender_mac_addr}/{arp_packet.sender_ip} " - f"to {arp_packet.target_ip}/{arp_packet.target_mac_addr} " - ) + self.send_arp_reply(arp_packet, from_nic) - tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + def __contains__(self, item) -> bool: + return item in self.arp - # Network Layer - ip_packet = IPPacket( - src_ip=arp_packet.sender_ip, - dst_ip=arp_packet.target_ip, - ) - # Data Link Layer - ethernet_header = EthernetHeader( - src_mac_addr=arp_packet.sender_mac_addr, dst_mac_addr=arp_packet.target_mac_addr - ) - frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet) - from_nic.send_frame(frame) class ICMP: """ @@ -712,8 +737,7 @@ class ICMP: self.arp: ARPCache = arp_cache self.request_replies = {} - - def process_icmp(self, frame: Frame): + def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False): """ Process an ICMP packet, including handling echo requests and replies. @@ -722,7 +746,15 @@ class ICMP: if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST: self.sys_log.info(f"Received echo request from {frame.ip.src_ip}") target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip) + src_nic = self.arp.get_arp_cache_nic(frame.ip.src_ip) + if not src_nic: + print(self.sys_log.hostname) + print(frame.ip.src_ip) + self.arp.show() + self.arp.send_arp_request(frame.ip.src_ip) + self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True) + return tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) # Network Layer @@ -737,6 +769,7 @@ class ICMP: ) frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet) self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip}") + src_nic.send_frame(frame) elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY: self.sys_log.info(f"Received echo reply from {frame.ip.src_ip}") @@ -745,7 +778,7 @@ class ICMP: self.request_replies[frame.icmp.identifier] += 1 def ping( - self, target_ip_address: IPv4Address, sequence: int = 0, identifier: Optional[int] = None + self, target_ip_address: IPv4Address, sequence: int = 0, identifier: Optional[int] = None, pings: int = 4 ) -> Tuple[int, Union[int, None]]: """ Send an ICMP echo request (ping) to a target IP address and manage the sequence and identifier. @@ -757,13 +790,21 @@ class ICMP: was not found in the ARP cache. """ nic = self.arp.get_arp_cache_nic(target_ip_address) - # TODO: Eventually this ARP request needs to be done elsewhere. It's not the resonsibility of the + # TODO: Eventually this ARP request needs to be done elsewhere. It's not the responsibility of the # ping function to handle ARP lookups + + # Already tried once and cannot get ARP entry, stop trying + if sequence == -1: + if not nic: + return 4, None + else: + sequence = 0 + # No existing ARP entry if not nic: self.sys_log.info(f"No entry in ARP cache for {target_ip_address}") self.arp.send_arp_request(target_ip_address) - return 0, None + return -1, None # ARP entry exists sequence += 1 @@ -812,6 +853,8 @@ class Node(SimComponent): hostname: str "The node hostname on the network." + default_gateway: Optional[IPv4Address] = None + "The default gateway IP address for forwarding network traffic to other networks." operating_state: NodeOperatingState = NodeOperatingState.OFF "The hardware state of the node." nics: Dict[str, NIC] = {} @@ -843,9 +886,12 @@ class Node(SimComponent): This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not provided. """ + if kwargs.get("default_gateway"): + if not isinstance(kwargs["default_gateway"], IPv4Address): + kwargs["default_gateway"] = IPv4Address(kwargs["default_gateway"]) if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["hostname"]) - if not kwargs.get("arp_cache"): + if not kwargs.get("arp"): kwargs["arp"] = ARPCache(sys_log=kwargs.get("sys_log")) if not kwargs.get("icmp"): kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) @@ -886,10 +932,8 @@ class Node(SimComponent): def show(self): """Prints a table of the NICs on the Node.""" - from prettytable import PrettyTable - table = PrettyTable(["MAC Address", "Address", "Default Gateway", "Speed", "Status"]) - + table.title = f"{self.hostname} Network Interface Cards" for nic in self.nics.values(): table.add_row( [ @@ -967,13 +1011,18 @@ class Node(SimComponent): """ if not isinstance(target_ip_address, IPv4Address): target_ip_address = IPv4Address(target_ip_address) + if target_ip_address.is_loopback: + self.sys_log.info("Pinging loopback address") + return any(nic.enabled for nic in self.nics.values()) if self.operating_state == NodeOperatingState.ON: self.sys_log.info(f"Attempting to ping {target_ip_address}") sequence, identifier = 0, None while sequence < pings: - sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier) - passed = self.icmp.request_replies[identifier] == pings - self.icmp.request_replies.pop(identifier) + sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings) + request_replies = self.icmp.request_replies.get(identifier) + passed = request_replies == pings + if request_replies: + self.icmp.request_replies.pop(identifier) return passed self.sys_log.info("Ping failed as the node is turned off") return False @@ -997,13 +1046,18 @@ class Node(SimComponent): :param frame: The Frame being received. :param from_nic: The NIC that received the frame. """ + if frame.ip: + if frame.ip.src_ip in self.arp: + self.arp.add_arp_cache_entry( + ip_address=frame.ip.src_ip, mac_address=frame.ethernet.src_mac_addr, nic=from_nic + ) if frame.ip.protocol == IPProtocol.TCP: if frame.tcp.src_port == Port.ARP: self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) elif frame.ip.protocol == IPProtocol.UDP: pass elif frame.ip.protocol == IPProtocol.ICMP: - self.icmp.process_icmp(frame=frame) + self.icmp.process_icmp(frame=frame, from_nic=from_nic) class Switch(Node): @@ -1027,7 +1081,7 @@ class Switch(Node): def show(self): """Prints a table of the SwitchPorts on the Switch.""" table = PrettyTable(["Port", "MAC Address", "Speed", "Status"]) - + table.title = f"{self.hostname} Switch Ports" for port_num, port in self.switch_ports.items(): table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) print(table) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index c5620b88..528e4a73 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -1,63 +1,514 @@ -from enum import Enum -from ipaddress import IPv4Address -from typing import Dict, List, Union +from __future__ import annotations + +from enum import Enum +from ipaddress import IPv4Address, IPv4Network +from typing import Dict, List, Optional, Tuple, Union -from primaite.simulator.core import SimComponent -from primaite.simulator.network.hardware.base import Node, NIC from prettytable import PrettyTable -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.core import SimComponent +from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node +from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame +from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader +from primaite.simulator.system.core.sys_log import SysLog class ACLAction(Enum): + """Enum for defining the ACL action types.""" + DENY = 0 PERMIT = 1 class ACLRule(SimComponent): - action: ACLAction - protocol: IPProtocol - src_ip: IPv4Address - src_wildcard: IPv4Address = IPv4Address("0.0.0.0") - src_port: Port - dst_ip: IPv4Address - dst_port: Port + def describe_state(self) -> Dict: + pass + + action: ACLAction = ACLAction.DENY + protocol: Optional[IPProtocol] = None + src_ip: Optional[IPv4Address] = None + src_port: Optional[Port] = None + dst_ip: Optional[IPv4Address] = None + dst_port: Optional[Port] = None + + def __str__(self) -> str: + rule_strings = [] + for key, value in self.model_dump(exclude={"uuid", "action_manager"}).items(): + if value is None: + value = "ANY" + if isinstance(value, Enum): + rule_strings.append(f"{key}={value.name}") + else: + rule_strings.append(f"{key}={value}") + return ", ".join(rule_strings) -class RouteTableEntry(SimComponent): - pass +class AccessControlList(SimComponent): + sys_log: SysLog + implicit_action: ACLAction + implicit_rule: ACLRule + max_acl_rules: int = 25 + _acl: List[Optional[ACLRule]] = [None] * 24 + + def __init__(self, **kwargs) -> None: + if not kwargs.get("implicit_action"): + kwargs["implicit_action"] = ACLAction.DENY + if not kwargs.get("max_acl_rules"): + kwargs["max_acl_rules"] = 25 + kwargs["implicit_rule"] = ACLRule(action=kwargs["implicit_action"]) + kwargs["_acl"] = [None] * (kwargs["max_acl_rules"] - 1) + + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + pass + + @property + def acl(self) -> List[Optional[ACLRule]]: + return self._acl + + def add_rule( + self, + action: ACLAction, + protocol: Optional[IPProtocol] = None, + src_ip: Optional[Union[str, IPv4Address]] = None, + src_port: Optional[Port] = None, + dst_ip: Optional[Union[str, IPv4Address]] = None, + dst_port: Optional[Port] = None, + position: int = 0, + ) -> None: + if isinstance(src_ip, str): + src_ip = IPv4Address(src_ip) + if isinstance(dst_ip, str): + dst_ip = IPv4Address(dst_ip) + if 0 <= position < self.max_acl_rules: + self._acl[position] = ACLRule( + action=action, src_ip=src_ip, dst_ip=dst_ip, protocol=protocol, src_port=src_port, dst_port=dst_port + ) + else: + raise ValueError(f"Position {position} is out of bounds.") + + def remove_rule(self, position: int) -> None: + if 0 <= position < self.max_acl_rules: + self._acl[position] = None + else: + raise ValueError(f"Position {position} is out of bounds.") + + def is_permitted( + self, + protocol: IPProtocol, + src_ip: Union[str, IPv4Address], + src_port: Optional[Port], + dst_ip: Union[str, IPv4Address], + dst_port: Optional[Port], + ) -> Tuple[bool, Optional[Union[str, ACLRule]]]: + if not isinstance(src_ip, IPv4Address): + src_ip = IPv4Address(src_ip) + if not isinstance(dst_ip, IPv4Address): + dst_ip = IPv4Address(dst_ip) + for rule in self._acl: + if not rule: + continue + + if ( + (rule.src_ip == src_ip or rule.src_ip is None) + and (rule.dst_ip == dst_ip or rule.dst_ip is None) + and (rule.protocol == protocol or rule.protocol is None) + and (rule.src_port == src_port or rule.src_port is None) + and (rule.dst_port == dst_port or rule.dst_port is None) + ): + return rule.action == ACLAction.PERMIT, rule + + return self.implicit_action == ACLAction.PERMIT, f"Implicit {self.implicit_action.name}" + + def get_relevant_rules( + self, + protocol: IPProtocol, + src_ip: Union[str, IPv4Address], + src_port: Port, + dst_ip: Union[str, IPv4Address], + dst_port: Port, + ) -> List[ACLRule]: + if not isinstance(src_ip, IPv4Address): + src_ip = IPv4Address(src_ip) + if not isinstance(dst_ip, IPv4Address): + dst_ip = IPv4Address(dst_ip) + relevant_rules = [] + for rule in self._acl: + if rule is None: + continue + + if ( + (rule.src_ip == src_ip or rule.src_ip is None) + or (rule.dst_ip == dst_ip or rule.dst_ip is None) + or (rule.protocol == protocol or rule.protocol is None) + or (rule.src_port == src_port or rule.src_port is None) + or (rule.dst_port == dst_port or rule.dst_port is None) + ): + relevant_rules.append(rule) + + return relevant_rules + + def show(self): + """Prints a table of the routes in the RouteTable.""" + """ + action: ACLAction + protocol: Optional[IPProtocol] + src_ip: Optional[IPv4Address] + src_port: Optional[Port] + dst_ip: Optional[IPv4Address] + dst_port: Optional[Port] + """ + table = PrettyTable(["Index", "Action", "Protocol", "Src IP", "Src Port", "Dst IP", "Dst Port"]) + table.title = f"{self.sys_log.hostname} Access Control List" + for index, rule in enumerate(self.acl + [self.implicit_rule]): + if rule: + table.add_row( + [ + index, + rule.action.name if rule.action else "ANY", + rule.protocol.name if rule.protocol else "ANY", + rule.src_ip if rule.src_ip else "ANY", + f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY", + rule.dst_ip if rule.dst_ip else "ANY", + f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY", + ] + ) + print(table) + + +class RouteEntry(SimComponent): + """ + Represents a single entry in a routing table. + + Attributes: + address (IPv4Address): The destination IP address or network address. + subnet_mask (IPv4Address): The subnet mask for the network. + next_hop (IPv4Address): The next hop IP address to which packets should be forwarded. + metric (int): The cost metric for this route. Default is 0.0. + + Example: + >>> entry = RouteEntry( + ... IPv4Address("192.168.1.0"), + ... IPv4Address("255.255.255.0"), + ... IPv4Address("192.168.2.1"), + ... metric=5 + ... ) + """ + + address: IPv4Address + "The destination IP address or network address." + subnet_mask: IPv4Address + "The subnet mask for the network." + next_hop: IPv4Address + "The next hop IP address to which packets should be forwarded." + metric: float = 0.0 + "The cost metric for this route. Default is 0.0." + + def __init__(self, **kwargs): + for key in {"address", "subnet_mask", "next_hop"}: + if not isinstance(kwargs[key], IPv4Address): + kwargs[key] = IPv4Address(kwargs[key]) + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + pass + + +class RouteTable(SimComponent): + """ + Represents a routing table holding multiple route entries. + + Attributes: + routes (List[RouteEntry]): A list of RouteEntry objects. + + Methods: + add_route: Add a route to the routing table. + find_best_route: Find the best route for a given destination IP. + + Example: + >>> rt = RouteTable() + >>> rt.add_route( + ... RouteEntry( + ... IPv4Address("192.168.1.0"), + ... IPv4Address("255.255.255.0"), + ... IPv4Address("192.168.2.1"), + ... metric=5 + ... ) + ... ) + >>> best_route = rt.find_best_route(IPv4Address("192.168.1.34")) + """ + + routes: List[RouteEntry] = [] + sys_log: SysLog + + def describe_state(self) -> Dict: + pass + + def add_route( + self, + address: Union[IPv4Address, str], + subnet_mask: Union[IPv4Address, str], + next_hop: Union[IPv4Address, str], + metric: float = 0.0, + ): + """Add a route to the routing table. + + :param route: A RouteEntry object representing the route. + """ + for key in {address, subnet_mask, next_hop}: + if not isinstance(key, IPv4Address): + key = IPv4Address(key) + route = RouteEntry(address=address, subnet_mask=subnet_mask, next_hop=next_hop, metric=metric) + self.routes.append(route) + + def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]: + """ + Find the best route for a given destination IP. + + :param destination_ip: The destination IPv4Address to find the route for. + :return: The best matching RouteEntry, or None if no route matches. + + The algorithm uses Longest Prefix Match and considers metrics to find the best route. + """ + if not isinstance(destination_ip, IPv4Address): + destination_ip = IPv4Address(destination_ip) + best_route = None + longest_prefix = -1 + lowest_metric = float("inf") # Initialise at infinity as any other number we compare to it will be smaller + + for route in self.routes: + route_network = IPv4Network(f"{route.address}/{route.subnet_mask}", strict=False) + prefix_len = route_network.prefixlen + + if destination_ip in route_network: + if prefix_len > longest_prefix or (prefix_len == longest_prefix and route.metric < lowest_metric): + best_route = route + longest_prefix = prefix_len + lowest_metric = route.metric + + return best_route + + def show(self): + """Prints a table of the routes in the RouteTable.""" + table = PrettyTable(["Index", "Address", "Next Hop", "Metric"]) + table.title = f"{self.sys_log.hostname} Route Table" + for index, route in enumerate(self.routes): + network = IPv4Network(f"{route.address}/{route.subnet_mask}") + table.add_row([index, f"{route.address}/{network.prefixlen}", route.next_hop, route.metric]) + print(table) + + +class RouterARPCache(ARPCache): + def __init__(self, sys_log: SysLog, router: Router): + super().__init__(sys_log) + self.router: Router = router + + def process_arp_packet(self, from_nic: NIC, frame: Frame): + """ + Overridden method to process a received ARP packet in a router-specific way. + + :param from_nic: The NIC that received the ARP packet. + :param frame: The original arp frame. + """ + arp_packet = frame.arp + + # ARP Reply + if not arp_packet.request: + for nic in self.router.nics.values(): + if arp_packet.target_ip == nic.ip_address: + # reply to the Router specifically + self.sys_log.info( + f"Received ARP response for {arp_packet.sender_ip} from {arp_packet.sender_mac_addr} via NIC {from_nic}" + ) + self.add_arp_cache_entry( + ip_address=arp_packet.sender_ip, + mac_address=arp_packet.sender_mac_addr, + nic=from_nic, + ) + return + + # Reply for a connected requested + nic = self.get_arp_cache_nic(arp_packet.target_ip) + if nic: + self.sys_log.info(f"Forwarding arp reply for {arp_packet.target_ip}, from {arp_packet.sender_ip}") + arp_packet.sender_mac_addr = nic.mac_address + frame.decrement_ttl() + nic.send_frame(frame) + + # ARP Request + self.sys_log.info( + f"Received ARP request for {arp_packet.target_ip} from " + f"{arp_packet.sender_mac_addr}/{arp_packet.sender_ip} " + ) + # Matched ARP request + self.add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic) + arp_packet = arp_packet.generate_reply(from_nic.mac_address) + self.send_arp_reply(arp_packet, from_nic) + + # If the target IP matches one of the router's NICs + for nic in self.nics.values(): + if nic.enabled and nic.ip_address == arp_packet.target_ip: + arp_reply = arp_packet.generate_reply(from_nic.mac_address) + self.send_arp_reply(arp_reply, from_nic) + return + + +class RouterICMP(ICMP): + router: Router + + def __init__(self, sys_log: SysLog, arp_cache: ARPCache, router: Router): + super().__init__(sys_log, arp_cache) + self.router = router + + def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False): + if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST: + # determine if request is for router interface or whether it needs to be routed + + for nic in self.router.nics.values(): + if nic.ip_address == frame.ip.dst_ip and nic.enabled: + # reply to the request + self.sys_log.info(f"Received echo request from {frame.ip.src_ip}") + target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip) + src_nic = self.arp.get_arp_cache_nic(frame.ip.src_ip) + tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) + + # Network Layer + ip_packet = IPPacket(src_ip=nic.ip_address, dst_ip=frame.ip.src_ip, protocol=IPProtocol.ICMP) + # Data Link Layer + ethernet_header = EthernetHeader(src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address) + icmp_reply_packet = ICMPPacket( + icmp_type=ICMPType.ECHO_REPLY, + icmp_code=0, + identifier=frame.icmp.identifier, + sequence=frame.icmp.sequence + 1, + ) + frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet) + self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip}") + + src_nic.send_frame(frame) + return + + # Route the frame + self.router.route_frame(frame, from_nic) + elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY: + self.sys_log.info(f"Received echo reply from {frame.ip.src_ip}") + if not self.request_replies.get(frame.icmp.identifier): + self.request_replies[frame.icmp.identifier] = 0 + self.request_replies[frame.icmp.identifier] += 1 class Router(Node): num_ports: int ethernet_ports: Dict[int, NIC] = {} - acl: List = [] - route_table: Dict = {} + acl: AccessControlList + route_table: RouteTable + arp: RouterARPCache + icmp: RouterICMP def __init__(self, hostname: str, num_ports: int = 5, **kwargs): + if not kwargs.get("sys_log"): + kwargs["sys_log"] = SysLog(hostname) + if not kwargs.get("acl"): + kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY) + if not kwargs.get("route_table"): + kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) + if not kwargs.get("arp"): + kwargs["arp"] = RouterARPCache(sys_log=kwargs.get("sys_log"), router=self) + if not kwargs.get("icmp"): + kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self) super().__init__(hostname=hostname, num_ports=num_ports, **kwargs) - for i in range(1, self.num_ports + 1): nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(nic) self.ethernet_ports[i] = nic + self.arp.nics = self.nics + self.icmp.arp = self.arp + + def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]: + for port, nic in self.ethernet_ports.items(): + if nic == target_nic: + return port + def describe_state(self) -> Dict: pass - def configure_port( - self, - port: int, - ip_address: Union[IPv4Address, str], - subnet_mask: str - ): + def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: + if not re_attempt: + # Check if src ip is on network of one of the NICs + nic = self.arp.get_arp_cache_nic(frame.ip.dst_ip) + target_mac = self.arp.get_arp_cache_mac_address(frame.ip.dst_ip) + if not nic: + self.arp.send_arp_request(frame.ip.dst_ip) + return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True) + for nic in self.nics.values(): + if nic.enabled and frame.ip.dst_ip in nic.ip_network: + from_port = self._get_port_of_nic(from_nic) + to_port = self._get_port_of_nic(nic) + self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}") + frame.decrement_ttl() + frame.ethernet.src_mac_addr = nic.mac_address + frame.ethernet.dst_mac_addr = target_mac + nic.send_frame(frame) + return + else: + self.sys_log.info(f"Destination {frame.ip.dst_ip} is unreachable") + + def receive_frame(self, frame: Frame, from_nic: NIC): + """ + Receive a Frame from the connected NIC and process it. + + Depending on the protocol, the frame is passed to the appropriate handler such as ARP or ICMP, or up to the + SessionManager if no code manager exists. + + :param frame: The Frame being received. + :param from_nic: The NIC that received the frame. + """ + route_frame = False + protocol = frame.ip.protocol + src_ip = frame.ip.src_ip + dst_ip = frame.ip.dst_ip + src_port = None + dst_port = None + if frame.ip.protocol == IPProtocol.TCP: + src_port = frame.tcp.src_port + dst_port = frame.tcp.dst_port + elif frame.ip.protocol == IPProtocol.UDP: + src_port = frame.udp.src_port + dst_port = frame.udp.dst_port + + # Check if it's permitted + permitted, rule = self.acl.is_permitted( + protocol=protocol, src_ip=src_ip, src_port=src_port, dst_ip=dst_ip, dst_port=dst_port + ) + if not permitted: + at_port = self._get_port_of_nic(from_nic) + self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}") + return + if not self.arp.get_arp_cache_nic(src_ip): + self.arp.add_arp_cache_entry(src_ip, frame.ethernet.src_mac_addr, from_nic) + if frame.ip.protocol == IPProtocol.ICMP: + self.icmp.process_icmp(frame=frame, from_nic=from_nic) + else: + if src_port == Port.ARP: + self.arp.process_arp_packet(from_nic=from_nic, frame=frame) + else: + # All other traffic + route_frame = True + if route_frame: + self.route_frame(frame, from_nic) + + def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]): if not isinstance(ip_address, IPv4Address): ip_address = IPv4Address(ip_address) + if not isinstance(subnet_mask, IPv4Address): + subnet_mask = IPv4Address(subnet_mask) nic = self.ethernet_ports[port] nic.ip_address = ip_address nic.subnet_mask = subnet_mask - self.sys_log.info(f"Configured port {port} with {ip_address=} {subnet_mask=}") + self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}") def enable_port(self, port: int): nic = self.ethernet_ports.get(port) @@ -72,7 +523,7 @@ class Router(Node): def show(self): """Prints a table of the NICs on the Node.""" table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"]) - + table.title = f"{self.hostname} Ethernet Interfaces" for port, nic in self.ethernet_ports.items(): table.add_row( [ diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index d3d6541a..34b76060 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -3,14 +3,13 @@ from primaite.simulator.network.hardware.base import Link, NIC, Node, Switch def test_node_to_node_ping(): """Tests two Nodes are able to ping each other.""" - # TODO Add actual checks. Manual check performed for now. node_a = Node(hostname="node_a") - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") node_a.connect_nic(nic_a) node_a.power_on() node_b = Node(hostname="node_b") - nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") node_b.connect_nic(nic_b) node_b.power_on() @@ -23,19 +22,19 @@ def test_multi_nic(): """Tests that Nodes with multiple NICs can ping each other and the data go across the correct links.""" # TODO Add actual checks. Manual check performed for now. node_a = Node(hostname="node_a") - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") node_a.connect_nic(nic_a) node_a.power_on() node_b = Node(hostname="node_b") - nic_b1 = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") - nic_b2 = NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0", gateway="10.0.0.1") + nic_b1 = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") + nic_b2 = NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0") node_b.connect_nic(nic_b1) node_b.connect_nic(nic_b2) node_b.power_on() node_c = Node(hostname="node_c") - nic_c = NIC(ip_address="10.0.0.13", subnet_mask="255.0.0.0", gateway="10.0.0.1") + nic_c = NIC(ip_address="10.0.0.13", subnet_mask="255.0.0.0") node_c.connect_nic(nic_c) node_c.power_on() @@ -52,22 +51,22 @@ def test_switched_network(): """Tests a larges network of Nodes and Switches with one node pinging another.""" # TODO Add actual checks. Manual check performed for now. pc_a = Node(hostname="pc_a") - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") pc_a.connect_nic(nic_a) pc_a.power_on() pc_b = Node(hostname="pc_b") - nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") pc_b.connect_nic(nic_b) pc_b.power_on() pc_c = Node(hostname="pc_c") - nic_c = NIC(ip_address="192.168.0.12", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_c = NIC(ip_address="192.168.0.12", subnet_mask="255.255.255.0") pc_c.connect_nic(nic_c) pc_c.power_on() pc_d = Node(hostname="pc_d") - nic_d = NIC(ip_address="192.168.0.13", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_d = NIC(ip_address="192.168.0.13", subnet_mask="255.255.255.0") pc_d.connect_nic(nic_d) pc_d.power_on() diff --git a/tests/integration_tests/network/test_link_connection.py b/tests/integration_tests/network/test_link_connection.py index e08e40b9..ef65f078 100644 --- a/tests/integration_tests/network/test_link_connection.py +++ b/tests/integration_tests/network/test_link_connection.py @@ -4,18 +4,17 @@ from primaite.simulator.network.hardware.base import Link, NIC, Node def test_link_up(): """Tests Nodes, NICs, and Links can all be connected and be in an enabled/up state.""" node_a = Node(hostname="node_a") - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") node_a.connect_nic(nic_a) node_a.power_on() - assert nic_a.enabled node_b = Node(hostname="node_b") - nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") node_b.connect_nic(nic_b) node_b.power_on() - assert nic_b.enabled - link = Link(endpoint_a=nic_a, endpoint_b=nic_b) + assert nic_a.enabled + assert nic_b.enabled assert link.is_up diff --git a/tests/integration_tests/network/test_nic_link_connection.py b/tests/integration_tests/network/test_nic_link_connection.py index 52a0c735..f051d026 100644 --- a/tests/integration_tests/network/test_nic_link_connection.py +++ b/tests/integration_tests/network/test_nic_link_connection.py @@ -8,7 +8,6 @@ def test_link_fails_with_same_nic(): with pytest.raises(ValueError): nic_a = NIC( ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - gateway="192.168.0.1", + subnet_mask="255.255.255.0" ) Link(endpoint_a=nic_a, endpoint_b=nic_a) diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index cca48c0d..cb420e22 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -1,27 +1,55 @@ -from primaite.simulator.network.hardware.base import Node, NIC, Link -from primaite.simulator.network.hardware.nodes.router import Router +from typing import Tuple + +import pytest + +from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port -def test_ping_fails_with_no_route(): - """Tests a larges network of Nodes and Switches with one node pinging another.""" - pc_a = Node(hostname="pc_a") - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") +@pytest.fixture(scope="function") +def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]: + pc_a = Node(hostname="pc_a", default_gateway="192.168.0.1") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") pc_a.connect_nic(nic_a) pc_a.power_on() - pc_b = Node(hostname="pc_b") - nic_b = NIC(ip_address="192.168.1.10", subnet_mask="255.255.255.0", gateway="192.168.1.1") + pc_b = Node(hostname="pc_b", default_gateway="192.168.1.1") + nic_b = NIC(ip_address="192.168.1.10", subnet_mask="255.255.255.0") pc_b.connect_nic(nic_b) pc_b.power_on() router_1 = Router(hostname="router_1") + router_1.power_on() + router_1.configure_port(1, "192.168.0.1", "255.255.255.0") router_1.configure_port(2, "192.168.1.1", "255.255.255.0") - router_1.power_on() - router_1.show() + Link(endpoint_a=nic_a, endpoint_b=router_1.ethernet_ports[1]) + Link(endpoint_a=nic_b, endpoint_b=router_1.ethernet_ports[2]) + router_1.enable_port(1) + router_1.enable_port(2) - link_nic_a_router_1 = Link(endpoint_a=nic_a, endpoint_b=router_1.ethernet_ports[1]) - link_nic_b_router_1 = Link(endpoint_a=nic_b, endpoint_b=router_1.ethernet_ports[2]) - router_1.power_on() - #assert pc_a.ping("192.168.1.10") \ No newline at end of file + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + return pc_a, pc_b, router_1 + + +def test_ping_default_gateway(pc_a_pc_b_router_1): + pc_a, pc_b, router_1 = pc_a_pc_b_router_1 + + assert pc_a.ping(pc_a.default_gateway) + + +def test_ping_other_router_port(pc_a_pc_b_router_1): + pc_a, pc_b, router_1 = pc_a_pc_b_router_1 + + assert pc_a.ping(pc_b.default_gateway) + + +def test_host_on_other_subnet(pc_a_pc_b_router_1): + pc_a, pc_b, router_1 = pc_a_pc_b_router_1 + + assert pc_a.ping("192.168.1.10") diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/__init__.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py new file mode 100644 index 00000000..48d0fc06 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -0,0 +1,104 @@ +from ipaddress import IPv4Address + +from primaite.simulator.network.hardware.nodes.router import AccessControlList, ACLAction, ACLRule +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port + + +def test_add_rule(): + acl = AccessControlList() + acl.add_rule( + action=ACLAction.PERMIT, + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.1"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.2"), + dst_port=Port(80), + position=1, + ) + assert acl.acl[1].action == ACLAction.PERMIT + assert acl.acl[1].protocol == IPProtocol.TCP + assert acl.acl[1].src_ip == IPv4Address("192.168.1.1") + assert acl.acl[1].src_port == Port(8080) + assert acl.acl[1].dst_ip == IPv4Address("192.168.1.2") + assert acl.acl[1].dst_port == Port(80) + + +def test_remove_rule(): + acl = AccessControlList() + acl.add_rule( + action=ACLAction.PERMIT, + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.1"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.2"), + dst_port=Port(80), + position=1, + ) + acl.remove_rule(1) + assert not acl.acl[1] + + +def test_rules(): + acl = AccessControlList() + acl.add_rule( + action=ACLAction.PERMIT, + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.1"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.2"), + dst_port=Port(80), + position=1, + ) + acl.add_rule( + action=ACLAction.DENY, + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.3"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.4"), + dst_port=Port(80), + position=2, + ) + assert acl.is_permitted( + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.1"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.2"), + dst_port=Port(80), + ) + assert not acl.is_permitted( + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.3"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.4"), + dst_port=Port(80), + ) + + +def test_default_rule(): + acl = AccessControlList() + acl.add_rule( + action=ACLAction.PERMIT, + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.1"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.2"), + dst_port=Port(80), + position=1, + ) + acl.add_rule( + action=ACLAction.DENY, + protocol=IPProtocol.TCP, + src_ip=IPv4Address("192.168.1.3"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.4"), + dst_port=Port(80), + position=2, + ) + assert not acl.is_permitted( + protocol=IPProtocol.UDP, + src_ip=IPv4Address("192.168.1.5"), + src_port=Port(8080), + dst_ip=IPv4Address("192.168.1.12"), + dst_port=Port(80), + ) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py index dc508508..11873128 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py @@ -32,10 +32,8 @@ def test_nic_ip_address_type_conversion(): nic = NIC( ip_address="192.168.1.2", subnet_mask="255.255.255.0", - gateway="192.168.0.1", ) assert isinstance(nic.ip_address, IPv4Address) - assert isinstance(nic.gateway, IPv4Address) def test_nic_deserialize(): @@ -43,7 +41,6 @@ def test_nic_deserialize(): nic = NIC( ip_address="192.168.1.2", subnet_mask="255.255.255.0", - gateway="192.168.0.1", ) nic_json = nic.model_dump_json() @@ -51,21 +48,10 @@ def test_nic_deserialize(): assert nic == deserialized_nic -def test_nic_ip_address_as_gateway_fails(): - """Tests NIC creation fails if ip address is the same as the gateway.""" - with pytest.raises(ValueError): - NIC( - ip_address="192.168.0.1", - subnet_mask="255.255.255.0", - gateway="192.168.0.1", - ) - - def test_nic_ip_address_as_network_address_fails(): """Tests NIC creation fails if ip address and subnet mask are a network address.""" with pytest.raises(ValueError): NIC( ip_address="192.168.0.0", subnet_mask="255.255.255.0", - gateway="192.168.0.1", )