#1800 - Added ACL and routing classes.

- Added .show() methods to new router classes to enable inspection of the components as you would a real router.
- Removed gateway from the NIC and added default_gateway to Node so that Node has a single default gateway.
- Added some routing tests to check that ping can be performed when router between subnets.
This commit is contained in:
Chris McCarthy
2023-08-30 21:38:55 +01:00
parent c6f71600fc
commit 1bf51c7741
9 changed files with 739 additions and 119 deletions

View File

@@ -77,12 +77,10 @@ class NIC(SimComponent):
ip_address: IPv4Address
"The IP address assigned to the NIC for communication on an IP-based network."
subnet_mask: str
subnet_mask: IPv4Address
"The subnet mask assigned to the NIC."
gateway: IPv4Address
"The default gateway IP address for forwarding network traffic to other networks. Randomly generated upon creation."
mac_address: str
"The MAC address of the NIC. Defaults to a randomly set MAC address."
"The MAC address of the NIC. Defaults to a randomly set MAC address. Randomly generated upon creation."
speed: int = 100
"The speed of the NIC in Mbps. Default is 100 Mbps."
mtu: int = 1500
@@ -111,16 +109,10 @@ class NIC(SimComponent):
"""
if not isinstance(kwargs["ip_address"], IPv4Address):
kwargs["ip_address"] = IPv4Address(kwargs["ip_address"])
if not isinstance(kwargs["gateway"], IPv4Address):
kwargs["gateway"] = IPv4Address(kwargs["gateway"])
if "mac_address" not in kwargs:
kwargs["mac_address"] = generate_mac_address()
super().__init__(**kwargs)
if self.ip_address == self.gateway:
msg = f"NIC ip address {self.ip_address} cannot be the same as the gateway {self.gateway}"
_LOGGER.error(msg)
raise ValueError(msg)
if self.ip_network.network_address == self.ip_address:
msg = (
f"Failed to set IP address {self.ip_address} and subnet mask {self.subnet_mask} as it is a "
@@ -274,6 +266,9 @@ class NIC(SimComponent):
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
self.connected_node.receive_frame(frame=frame, from_nic=self)
return True
else:
self.connected_node.sys_log.info("Dropping frame not for me")
print(frame)
return False
def __str__(self) -> str:
@@ -567,7 +562,21 @@ class ARPCache:
self.arp: Dict[IPv4Address, ARPEntry] = {}
self.nics: Dict[str, "NIC"] = {}
def _add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC):
def show(self):
"""Prints a table of ARC Cache."""
table = PrettyTable(["IP Address", "MAC Address", "Via"])
table.title = f"{self.sys_log.hostname} ARP Cache"
for ip, arp in self.arp.items():
table.add_row(
[
str(ip),
arp.mac_address,
self.nics[arp.nic_uuid].mac_address,
]
)
print(table)
def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False):
"""
Add an ARP entry to the cache.
@@ -575,9 +584,14 @@ class ARPCache:
:param mac_address: The MAC address associated with the IP address.
:param nic: The NIC through which the NIC with the IP address is reachable.
"""
self.sys_log.info(f"Adding ARP cache entry for {mac_address}/{ip_address} via NIC {nic}")
arp_entry = ARPEntry(mac_address=mac_address, nic_uuid=nic.uuid)
self.arp[ip_address] = arp_entry
for _nic in self.nics.values():
if _nic.ip_address == ip_address:
return
if override or not self.arp.get(ip_address):
self.sys_log.info(f"Adding ARP cache entry for {mac_address}/{ip_address} via NIC {nic}")
arp_entry = ARPEntry(mac_address=mac_address, nic_uuid=nic.uuid)
self.arp[ip_address] = arp_entry
def _remove_arp_cache_entry(self, ip_address: IPv4Address):
"""
@@ -607,6 +621,7 @@ class ARPCache:
:return: The NIC associated with the IP address, or None if not found.
"""
arp_entry = self.arp.get(ip_address)
if arp_entry:
return self.nics[arp_entry.nic_uuid]
@@ -641,6 +656,29 @@ class ARPCache:
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet)
nic.send_frame(frame)
def send_arp_reply(self, arp_reply: ARPPacket, from_nic: NIC):
"""
Send an ARP reply back through the NIC it came from.
:param arp_reply: The ARP reply to send.
:param from_nic: The NIC to send the ARP reply from.
"""
self.sys_log.info(
f"Sending ARP reply from {arp_reply.sender_mac_addr}/{arp_reply.sender_ip} "
f"to {arp_reply.target_ip}/{arp_reply.target_mac_addr} "
)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
ip_packet = IPPacket(
src_ip=arp_reply.sender_ip,
dst_ip=arp_reply.target_ip,
)
ethernet_header = EthernetHeader(src_mac_addr=arp_reply.sender_mac_addr, dst_mac_addr=arp_reply.target_mac_addr)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_reply)
from_nic.send_frame(frame)
def process_arp_packet(self, from_nic: NIC, arp_packet: ARPPacket):
"""
Process a received ARP packet, handling both ARP requests and responses.
@@ -656,7 +694,7 @@ class ARPCache:
self.sys_log.info(
f"Received ARP response for {arp_packet.sender_ip} from {arp_packet.sender_mac_addr} via NIC {from_nic}"
)
self._add_arp_cache_entry(
self.add_arp_cache_entry(
ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic
)
return
@@ -673,26 +711,13 @@ class ARPCache:
return
# Matched ARP request
self._add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic)
self.add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic)
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
self.sys_log.info(
f"Sending ARP reply from {arp_packet.sender_mac_addr}/{arp_packet.sender_ip} "
f"to {arp_packet.target_ip}/{arp_packet.target_mac_addr} "
)
self.send_arp_reply(arp_packet, from_nic)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
def __contains__(self, item) -> bool:
return item in self.arp
# Network Layer
ip_packet = IPPacket(
src_ip=arp_packet.sender_ip,
dst_ip=arp_packet.target_ip,
)
# Data Link Layer
ethernet_header = EthernetHeader(
src_mac_addr=arp_packet.sender_mac_addr, dst_mac_addr=arp_packet.target_mac_addr
)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet)
from_nic.send_frame(frame)
class ICMP:
"""
@@ -712,8 +737,7 @@ class ICMP:
self.arp: ARPCache = arp_cache
self.request_replies = {}
def process_icmp(self, frame: Frame):
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
"""
Process an ICMP packet, including handling echo requests and replies.
@@ -722,7 +746,15 @@ class ICMP:
if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST:
self.sys_log.info(f"Received echo request from {frame.ip.src_ip}")
target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip)
src_nic = self.arp.get_arp_cache_nic(frame.ip.src_ip)
if not src_nic:
print(self.sys_log.hostname)
print(frame.ip.src_ip)
self.arp.show()
self.arp.send_arp_request(frame.ip.src_ip)
self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True)
return
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
@@ -737,6 +769,7 @@ class ICMP:
)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet)
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip}")
src_nic.send_frame(frame)
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
self.sys_log.info(f"Received echo reply from {frame.ip.src_ip}")
@@ -745,7 +778,7 @@ class ICMP:
self.request_replies[frame.icmp.identifier] += 1
def ping(
self, target_ip_address: IPv4Address, sequence: int = 0, identifier: Optional[int] = None
self, target_ip_address: IPv4Address, sequence: int = 0, identifier: Optional[int] = None, pings: int = 4
) -> Tuple[int, Union[int, None]]:
"""
Send an ICMP echo request (ping) to a target IP address and manage the sequence and identifier.
@@ -757,13 +790,21 @@ class ICMP:
was not found in the ARP cache.
"""
nic = self.arp.get_arp_cache_nic(target_ip_address)
# TODO: Eventually this ARP request needs to be done elsewhere. It's not the resonsibility of the
# TODO: Eventually this ARP request needs to be done elsewhere. It's not the responsibility of the
# ping function to handle ARP lookups
# Already tried once and cannot get ARP entry, stop trying
if sequence == -1:
if not nic:
return 4, None
else:
sequence = 0
# No existing ARP entry
if not nic:
self.sys_log.info(f"No entry in ARP cache for {target_ip_address}")
self.arp.send_arp_request(target_ip_address)
return 0, None
return -1, None
# ARP entry exists
sequence += 1
@@ -812,6 +853,8 @@ class Node(SimComponent):
hostname: str
"The node hostname on the network."
default_gateway: Optional[IPv4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: NodeOperatingState = NodeOperatingState.OFF
"The hardware state of the node."
nics: Dict[str, NIC] = {}
@@ -843,9 +886,12 @@ class Node(SimComponent):
This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not
provided.
"""
if kwargs.get("default_gateway"):
if not isinstance(kwargs["default_gateway"], IPv4Address):
kwargs["default_gateway"] = IPv4Address(kwargs["default_gateway"])
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["hostname"])
if not kwargs.get("arp_cache"):
if not kwargs.get("arp"):
kwargs["arp"] = ARPCache(sys_log=kwargs.get("sys_log"))
if not kwargs.get("icmp"):
kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
@@ -886,10 +932,8 @@ class Node(SimComponent):
def show(self):
"""Prints a table of the NICs on the Node."""
from prettytable import PrettyTable
table = PrettyTable(["MAC Address", "Address", "Default Gateway", "Speed", "Status"])
table.title = f"{self.hostname} Network Interface Cards"
for nic in self.nics.values():
table.add_row(
[
@@ -967,13 +1011,18 @@ class Node(SimComponent):
"""
if not isinstance(target_ip_address, IPv4Address):
target_ip_address = IPv4Address(target_ip_address)
if target_ip_address.is_loopback:
self.sys_log.info("Pinging loopback address")
return any(nic.enabled for nic in self.nics.values())
if self.operating_state == NodeOperatingState.ON:
self.sys_log.info(f"Attempting to ping {target_ip_address}")
sequence, identifier = 0, None
while sequence < pings:
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier)
passed = self.icmp.request_replies[identifier] == pings
self.icmp.request_replies.pop(identifier)
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
request_replies = self.icmp.request_replies.get(identifier)
passed = request_replies == pings
if request_replies:
self.icmp.request_replies.pop(identifier)
return passed
self.sys_log.info("Ping failed as the node is turned off")
return False
@@ -997,13 +1046,18 @@ class Node(SimComponent):
:param frame: The Frame being received.
:param from_nic: The NIC that received the frame.
"""
if frame.ip:
if frame.ip.src_ip in self.arp:
self.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
)
if frame.ip.protocol == IPProtocol.TCP:
if frame.tcp.src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
elif frame.ip.protocol == IPProtocol.UDP:
pass
elif frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame)
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
class Switch(Node):
@@ -1027,7 +1081,7 @@ class Switch(Node):
def show(self):
"""Prints a table of the SwitchPorts on the Switch."""
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
table.title = f"{self.hostname} Switch Ports"
for port_num, port in self.switch_ports.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)

View File

@@ -1,63 +1,514 @@
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, List, Union
from __future__ import annotations
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Dict, List, Optional, Tuple, Union
from primaite.simulator.core import SimComponent
from primaite.simulator.network.hardware.base import Node, NIC
from prettytable import PrettyTable
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.core import SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
from primaite.simulator.system.core.sys_log import SysLog
class ACLAction(Enum):
"""Enum for defining the ACL action types."""
DENY = 0
PERMIT = 1
class ACLRule(SimComponent):
action: ACLAction
protocol: IPProtocol
src_ip: IPv4Address
src_wildcard: IPv4Address = IPv4Address("0.0.0.0")
src_port: Port
dst_ip: IPv4Address
dst_port: Port
def describe_state(self) -> Dict:
pass
action: ACLAction = ACLAction.DENY
protocol: Optional[IPProtocol] = None
src_ip: Optional[IPv4Address] = None
src_port: Optional[Port] = None
dst_ip: Optional[IPv4Address] = None
dst_port: Optional[Port] = None
def __str__(self) -> str:
rule_strings = []
for key, value in self.model_dump(exclude={"uuid", "action_manager"}).items():
if value is None:
value = "ANY"
if isinstance(value, Enum):
rule_strings.append(f"{key}={value.name}")
else:
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
class RouteTableEntry(SimComponent):
pass
class AccessControlList(SimComponent):
sys_log: SysLog
implicit_action: ACLAction
implicit_rule: ACLRule
max_acl_rules: int = 25
_acl: List[Optional[ACLRule]] = [None] * 24
def __init__(self, **kwargs) -> None:
if not kwargs.get("implicit_action"):
kwargs["implicit_action"] = ACLAction.DENY
if not kwargs.get("max_acl_rules"):
kwargs["max_acl_rules"] = 25
kwargs["implicit_rule"] = ACLRule(action=kwargs["implicit_action"])
kwargs["_acl"] = [None] * (kwargs["max_acl_rules"] - 1)
super().__init__(**kwargs)
def describe_state(self) -> Dict:
pass
@property
def acl(self) -> List[Optional[ACLRule]]:
return self._acl
def add_rule(
self,
action: ACLAction,
protocol: Optional[IPProtocol] = None,
src_ip: Optional[Union[str, IPv4Address]] = None,
src_port: Optional[Port] = None,
dst_ip: Optional[Union[str, IPv4Address]] = None,
dst_port: Optional[Port] = None,
position: int = 0,
) -> None:
if isinstance(src_ip, str):
src_ip = IPv4Address(src_ip)
if isinstance(dst_ip, str):
dst_ip = IPv4Address(dst_ip)
if 0 <= position < self.max_acl_rules:
self._acl[position] = ACLRule(
action=action, src_ip=src_ip, dst_ip=dst_ip, protocol=protocol, src_port=src_port, dst_port=dst_port
)
else:
raise ValueError(f"Position {position} is out of bounds.")
def remove_rule(self, position: int) -> None:
if 0 <= position < self.max_acl_rules:
self._acl[position] = None
else:
raise ValueError(f"Position {position} is out of bounds.")
def is_permitted(
self,
protocol: IPProtocol,
src_ip: Union[str, IPv4Address],
src_port: Optional[Port],
dst_ip: Union[str, IPv4Address],
dst_port: Optional[Port],
) -> Tuple[bool, Optional[Union[str, ACLRule]]]:
if not isinstance(src_ip, IPv4Address):
src_ip = IPv4Address(src_ip)
if not isinstance(dst_ip, IPv4Address):
dst_ip = IPv4Address(dst_ip)
for rule in self._acl:
if not rule:
continue
if (
(rule.src_ip == src_ip or rule.src_ip is None)
and (rule.dst_ip == dst_ip or rule.dst_ip is None)
and (rule.protocol == protocol or rule.protocol is None)
and (rule.src_port == src_port or rule.src_port is None)
and (rule.dst_port == dst_port or rule.dst_port is None)
):
return rule.action == ACLAction.PERMIT, rule
return self.implicit_action == ACLAction.PERMIT, f"Implicit {self.implicit_action.name}"
def get_relevant_rules(
self,
protocol: IPProtocol,
src_ip: Union[str, IPv4Address],
src_port: Port,
dst_ip: Union[str, IPv4Address],
dst_port: Port,
) -> List[ACLRule]:
if not isinstance(src_ip, IPv4Address):
src_ip = IPv4Address(src_ip)
if not isinstance(dst_ip, IPv4Address):
dst_ip = IPv4Address(dst_ip)
relevant_rules = []
for rule in self._acl:
if rule is None:
continue
if (
(rule.src_ip == src_ip or rule.src_ip is None)
or (rule.dst_ip == dst_ip or rule.dst_ip is None)
or (rule.protocol == protocol or rule.protocol is None)
or (rule.src_port == src_port or rule.src_port is None)
or (rule.dst_port == dst_port or rule.dst_port is None)
):
relevant_rules.append(rule)
return relevant_rules
def show(self):
"""Prints a table of the routes in the RouteTable."""
"""
action: ACLAction
protocol: Optional[IPProtocol]
src_ip: Optional[IPv4Address]
src_port: Optional[Port]
dst_ip: Optional[IPv4Address]
dst_port: Optional[Port]
"""
table = PrettyTable(["Index", "Action", "Protocol", "Src IP", "Src Port", "Dst IP", "Dst Port"])
table.title = f"{self.sys_log.hostname} Access Control List"
for index, rule in enumerate(self.acl + [self.implicit_rule]):
if rule:
table.add_row(
[
index,
rule.action.name if rule.action else "ANY",
rule.protocol.name if rule.protocol else "ANY",
rule.src_ip if rule.src_ip else "ANY",
f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY",
rule.dst_ip if rule.dst_ip else "ANY",
f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY",
]
)
print(table)
class RouteEntry(SimComponent):
"""
Represents a single entry in a routing table.
Attributes:
address (IPv4Address): The destination IP address or network address.
subnet_mask (IPv4Address): The subnet mask for the network.
next_hop (IPv4Address): The next hop IP address to which packets should be forwarded.
metric (int): The cost metric for this route. Default is 0.0.
Example:
>>> entry = RouteEntry(
... IPv4Address("192.168.1.0"),
... IPv4Address("255.255.255.0"),
... IPv4Address("192.168.2.1"),
... metric=5
... )
"""
address: IPv4Address
"The destination IP address or network address."
subnet_mask: IPv4Address
"The subnet mask for the network."
next_hop: IPv4Address
"The next hop IP address to which packets should be forwarded."
metric: float = 0.0
"The cost metric for this route. Default is 0.0."
def __init__(self, **kwargs):
for key in {"address", "subnet_mask", "next_hop"}:
if not isinstance(kwargs[key], IPv4Address):
kwargs[key] = IPv4Address(kwargs[key])
super().__init__(**kwargs)
def describe_state(self) -> Dict:
pass
class RouteTable(SimComponent):
"""
Represents a routing table holding multiple route entries.
Attributes:
routes (List[RouteEntry]): A list of RouteEntry objects.
Methods:
add_route: Add a route to the routing table.
find_best_route: Find the best route for a given destination IP.
Example:
>>> rt = RouteTable()
>>> rt.add_route(
... RouteEntry(
... IPv4Address("192.168.1.0"),
... IPv4Address("255.255.255.0"),
... IPv4Address("192.168.2.1"),
... metric=5
... )
... )
>>> best_route = rt.find_best_route(IPv4Address("192.168.1.34"))
"""
routes: List[RouteEntry] = []
sys_log: SysLog
def describe_state(self) -> Dict:
pass
def add_route(
self,
address: Union[IPv4Address, str],
subnet_mask: Union[IPv4Address, str],
next_hop: Union[IPv4Address, str],
metric: float = 0.0,
):
"""Add a route to the routing table.
:param route: A RouteEntry object representing the route.
"""
for key in {address, subnet_mask, next_hop}:
if not isinstance(key, IPv4Address):
key = IPv4Address(key)
route = RouteEntry(address=address, subnet_mask=subnet_mask, next_hop=next_hop, metric=metric)
self.routes.append(route)
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
"""
Find the best route for a given destination IP.
:param destination_ip: The destination IPv4Address to find the route for.
:return: The best matching RouteEntry, or None if no route matches.
The algorithm uses Longest Prefix Match and considers metrics to find the best route.
"""
if not isinstance(destination_ip, IPv4Address):
destination_ip = IPv4Address(destination_ip)
best_route = None
longest_prefix = -1
lowest_metric = float("inf") # Initialise at infinity as any other number we compare to it will be smaller
for route in self.routes:
route_network = IPv4Network(f"{route.address}/{route.subnet_mask}", strict=False)
prefix_len = route_network.prefixlen
if destination_ip in route_network:
if prefix_len > longest_prefix or (prefix_len == longest_prefix and route.metric < lowest_metric):
best_route = route
longest_prefix = prefix_len
lowest_metric = route.metric
return best_route
def show(self):
"""Prints a table of the routes in the RouteTable."""
table = PrettyTable(["Index", "Address", "Next Hop", "Metric"])
table.title = f"{self.sys_log.hostname} Route Table"
for index, route in enumerate(self.routes):
network = IPv4Network(f"{route.address}/{route.subnet_mask}")
table.add_row([index, f"{route.address}/{network.prefixlen}", route.next_hop, route.metric])
print(table)
class RouterARPCache(ARPCache):
def __init__(self, sys_log: SysLog, router: Router):
super().__init__(sys_log)
self.router: Router = router
def process_arp_packet(self, from_nic: NIC, frame: Frame):
"""
Overridden method to process a received ARP packet in a router-specific way.
:param from_nic: The NIC that received the ARP packet.
:param frame: The original arp frame.
"""
arp_packet = frame.arp
# ARP Reply
if not arp_packet.request:
for nic in self.router.nics.values():
if arp_packet.target_ip == nic.ip_address:
# reply to the Router specifically
self.sys_log.info(
f"Received ARP response for {arp_packet.sender_ip} from {arp_packet.sender_mac_addr} via NIC {from_nic}"
)
self.add_arp_cache_entry(
ip_address=arp_packet.sender_ip,
mac_address=arp_packet.sender_mac_addr,
nic=from_nic,
)
return
# Reply for a connected requested
nic = self.get_arp_cache_nic(arp_packet.target_ip)
if nic:
self.sys_log.info(f"Forwarding arp reply for {arp_packet.target_ip}, from {arp_packet.sender_ip}")
arp_packet.sender_mac_addr = nic.mac_address
frame.decrement_ttl()
nic.send_frame(frame)
# ARP Request
self.sys_log.info(
f"Received ARP request for {arp_packet.target_ip} from "
f"{arp_packet.sender_mac_addr}/{arp_packet.sender_ip} "
)
# Matched ARP request
self.add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic)
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_packet, from_nic)
# If the target IP matches one of the router's NICs
for nic in self.nics.values():
if nic.enabled and nic.ip_address == arp_packet.target_ip:
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_reply, from_nic)
return
class RouterICMP(ICMP):
router: Router
def __init__(self, sys_log: SysLog, arp_cache: ARPCache, router: Router):
super().__init__(sys_log, arp_cache)
self.router = router
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST:
# determine if request is for router interface or whether it needs to be routed
for nic in self.router.nics.values():
if nic.ip_address == frame.ip.dst_ip and nic.enabled:
# reply to the request
self.sys_log.info(f"Received echo request from {frame.ip.src_ip}")
target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip)
src_nic = self.arp.get_arp_cache_nic(frame.ip.src_ip)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(src_ip=nic.ip_address, dst_ip=frame.ip.src_ip, protocol=IPProtocol.ICMP)
# Data Link Layer
ethernet_header = EthernetHeader(src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address)
icmp_reply_packet = ICMPPacket(
icmp_type=ICMPType.ECHO_REPLY,
icmp_code=0,
identifier=frame.icmp.identifier,
sequence=frame.icmp.sequence + 1,
)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet)
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip}")
src_nic.send_frame(frame)
return
# Route the frame
self.router.route_frame(frame, from_nic)
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
self.sys_log.info(f"Received echo reply from {frame.ip.src_ip}")
if not self.request_replies.get(frame.icmp.identifier):
self.request_replies[frame.icmp.identifier] = 0
self.request_replies[frame.icmp.identifier] += 1
class Router(Node):
num_ports: int
ethernet_ports: Dict[int, NIC] = {}
acl: List = []
route_table: Dict = {}
acl: AccessControlList
route_table: RouteTable
arp: RouterARPCache
icmp: RouterICMP
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
if not kwargs.get("arp"):
kwargs["arp"] = RouterARPCache(sys_log=kwargs.get("sys_log"), router=self)
if not kwargs.get("icmp"):
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
for i in range(1, self.num_ports + 1):
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(nic)
self.ethernet_ports[i] = nic
self.arp.nics = self.nics
self.icmp.arp = self.arp
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
for port, nic in self.ethernet_ports.items():
if nic == target_nic:
return port
def describe_state(self) -> Dict:
pass
def configure_port(
self,
port: int,
ip_address: Union[IPv4Address, str],
subnet_mask: str
):
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
if not re_attempt:
# Check if src ip is on network of one of the NICs
nic = self.arp.get_arp_cache_nic(frame.ip.dst_ip)
target_mac = self.arp.get_arp_cache_mac_address(frame.ip.dst_ip)
if not nic:
self.arp.send_arp_request(frame.ip.dst_ip)
return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True)
for nic in self.nics.values():
if nic.enabled and frame.ip.dst_ip in nic.ip_network:
from_port = self._get_port_of_nic(from_nic)
to_port = self._get_port_of_nic(nic)
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
frame.decrement_ttl()
frame.ethernet.src_mac_addr = nic.mac_address
frame.ethernet.dst_mac_addr = target_mac
nic.send_frame(frame)
return
else:
self.sys_log.info(f"Destination {frame.ip.dst_ip} is unreachable")
def receive_frame(self, frame: Frame, from_nic: NIC):
"""
Receive a Frame from the connected NIC and process it.
Depending on the protocol, the frame is passed to the appropriate handler such as ARP or ICMP, or up to the
SessionManager if no code manager exists.
:param frame: The Frame being received.
:param from_nic: The NIC that received the frame.
"""
route_frame = False
protocol = frame.ip.protocol
src_ip = frame.ip.src_ip
dst_ip = frame.ip.dst_ip
src_port = None
dst_port = None
if frame.ip.protocol == IPProtocol.TCP:
src_port = frame.tcp.src_port
dst_port = frame.tcp.dst_port
elif frame.ip.protocol == IPProtocol.UDP:
src_port = frame.udp.src_port
dst_port = frame.udp.dst_port
# Check if it's permitted
permitted, rule = self.acl.is_permitted(
protocol=protocol, src_ip=src_ip, src_port=src_port, dst_ip=dst_ip, dst_port=dst_port
)
if not permitted:
at_port = self._get_port_of_nic(from_nic)
self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}")
return
if not self.arp.get_arp_cache_nic(src_ip):
self.arp.add_arp_cache_entry(src_ip, frame.ethernet.src_mac_addr, from_nic)
if frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
else:
if src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, frame=frame)
else:
# All other traffic
route_frame = True
if route_frame:
self.route_frame(frame, from_nic)
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
if not isinstance(ip_address, IPv4Address):
ip_address = IPv4Address(ip_address)
if not isinstance(subnet_mask, IPv4Address):
subnet_mask = IPv4Address(subnet_mask)
nic = self.ethernet_ports[port]
nic.ip_address = ip_address
nic.subnet_mask = subnet_mask
self.sys_log.info(f"Configured port {port} with {ip_address=} {subnet_mask=}")
self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}")
def enable_port(self, port: int):
nic = self.ethernet_ports.get(port)
@@ -72,7 +523,7 @@ class Router(Node):
def show(self):
"""Prints a table of the NICs on the Node."""
table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"])
table.title = f"{self.hostname} Ethernet Interfaces"
for port, nic in self.ethernet_ports.items():
table.add_row(
[

View File

@@ -3,14 +3,13 @@ from primaite.simulator.network.hardware.base import Link, NIC, Node, Switch
def test_node_to_node_ping():
"""Tests two Nodes are able to ping each other."""
# TODO Add actual checks. Manual check performed for now.
node_a = Node(hostname="node_a")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0")
node_a.connect_nic(nic_a)
node_a.power_on()
node_b = Node(hostname="node_b")
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
node_b.connect_nic(nic_b)
node_b.power_on()
@@ -23,19 +22,19 @@ def test_multi_nic():
"""Tests that Nodes with multiple NICs can ping each other and the data go across the correct links."""
# TODO Add actual checks. Manual check performed for now.
node_a = Node(hostname="node_a")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0")
node_a.connect_nic(nic_a)
node_a.power_on()
node_b = Node(hostname="node_b")
nic_b1 = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_b2 = NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0", gateway="10.0.0.1")
nic_b1 = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
nic_b2 = NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")
node_b.connect_nic(nic_b1)
node_b.connect_nic(nic_b2)
node_b.power_on()
node_c = Node(hostname="node_c")
nic_c = NIC(ip_address="10.0.0.13", subnet_mask="255.0.0.0", gateway="10.0.0.1")
nic_c = NIC(ip_address="10.0.0.13", subnet_mask="255.0.0.0")
node_c.connect_nic(nic_c)
node_c.power_on()
@@ -52,22 +51,22 @@ def test_switched_network():
"""Tests a larges network of Nodes and Switches with one node pinging another."""
# TODO Add actual checks. Manual check performed for now.
pc_a = Node(hostname="pc_a")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0")
pc_a.connect_nic(nic_a)
pc_a.power_on()
pc_b = Node(hostname="pc_b")
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
pc_b.connect_nic(nic_b)
pc_b.power_on()
pc_c = Node(hostname="pc_c")
nic_c = NIC(ip_address="192.168.0.12", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_c = NIC(ip_address="192.168.0.12", subnet_mask="255.255.255.0")
pc_c.connect_nic(nic_c)
pc_c.power_on()
pc_d = Node(hostname="pc_d")
nic_d = NIC(ip_address="192.168.0.13", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_d = NIC(ip_address="192.168.0.13", subnet_mask="255.255.255.0")
pc_d.connect_nic(nic_d)
pc_d.power_on()

View File

@@ -4,18 +4,17 @@ from primaite.simulator.network.hardware.base import Link, NIC, Node
def test_link_up():
"""Tests Nodes, NICs, and Links can all be connected and be in an enabled/up state."""
node_a = Node(hostname="node_a")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0")
node_a.connect_nic(nic_a)
node_a.power_on()
assert nic_a.enabled
node_b = Node(hostname="node_b")
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1")
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
node_b.connect_nic(nic_b)
node_b.power_on()
assert nic_b.enabled
link = Link(endpoint_a=nic_a, endpoint_b=nic_b)
assert nic_a.enabled
assert nic_b.enabled
assert link.is_up

View File

@@ -8,7 +8,6 @@ def test_link_fails_with_same_nic():
with pytest.raises(ValueError):
nic_a = NIC(
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
gateway="192.168.0.1",
subnet_mask="255.255.255.0"
)
Link(endpoint_a=nic_a, endpoint_b=nic_a)

View File

@@ -1,27 +1,55 @@
from primaite.simulator.network.hardware.base import Node, NIC, Link
from primaite.simulator.network.hardware.nodes.router import Router
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def test_ping_fails_with_no_route():
"""Tests a larges network of Nodes and Switches with one node pinging another."""
pc_a = Node(hostname="pc_a")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1")
@pytest.fixture(scope="function")
def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]:
pc_a = Node(hostname="pc_a", default_gateway="192.168.0.1")
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0")
pc_a.connect_nic(nic_a)
pc_a.power_on()
pc_b = Node(hostname="pc_b")
nic_b = NIC(ip_address="192.168.1.10", subnet_mask="255.255.255.0", gateway="192.168.1.1")
pc_b = Node(hostname="pc_b", default_gateway="192.168.1.1")
nic_b = NIC(ip_address="192.168.1.10", subnet_mask="255.255.255.0")
pc_b.connect_nic(nic_b)
pc_b.power_on()
router_1 = Router(hostname="router_1")
router_1.power_on()
router_1.configure_port(1, "192.168.0.1", "255.255.255.0")
router_1.configure_port(2, "192.168.1.1", "255.255.255.0")
router_1.power_on()
router_1.show()
Link(endpoint_a=nic_a, endpoint_b=router_1.ethernet_ports[1])
Link(endpoint_a=nic_b, endpoint_b=router_1.ethernet_ports[2])
router_1.enable_port(1)
router_1.enable_port(2)
link_nic_a_router_1 = Link(endpoint_a=nic_a, endpoint_b=router_1.ethernet_ports[1])
link_nic_b_router_1 = Link(endpoint_a=nic_b, endpoint_b=router_1.ethernet_ports[2])
router_1.power_on()
#assert pc_a.ping("192.168.1.10")
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
return pc_a, pc_b, router_1
def test_ping_default_gateway(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
assert pc_a.ping(pc_a.default_gateway)
def test_ping_other_router_port(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
assert pc_a.ping(pc_b.default_gateway)
def test_host_on_other_subnet(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
assert pc_a.ping("192.168.1.10")

View File

@@ -0,0 +1,104 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.router import AccessControlList, ACLAction, ACLRule
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def test_add_rule():
acl = AccessControlList()
acl.add_rule(
action=ACLAction.PERMIT,
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.1"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.2"),
dst_port=Port(80),
position=1,
)
assert acl.acl[1].action == ACLAction.PERMIT
assert acl.acl[1].protocol == IPProtocol.TCP
assert acl.acl[1].src_ip == IPv4Address("192.168.1.1")
assert acl.acl[1].src_port == Port(8080)
assert acl.acl[1].dst_ip == IPv4Address("192.168.1.2")
assert acl.acl[1].dst_port == Port(80)
def test_remove_rule():
acl = AccessControlList()
acl.add_rule(
action=ACLAction.PERMIT,
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.1"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.2"),
dst_port=Port(80),
position=1,
)
acl.remove_rule(1)
assert not acl.acl[1]
def test_rules():
acl = AccessControlList()
acl.add_rule(
action=ACLAction.PERMIT,
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.1"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.2"),
dst_port=Port(80),
position=1,
)
acl.add_rule(
action=ACLAction.DENY,
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.3"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.4"),
dst_port=Port(80),
position=2,
)
assert acl.is_permitted(
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.1"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.2"),
dst_port=Port(80),
)
assert not acl.is_permitted(
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.3"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.4"),
dst_port=Port(80),
)
def test_default_rule():
acl = AccessControlList()
acl.add_rule(
action=ACLAction.PERMIT,
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.1"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.2"),
dst_port=Port(80),
position=1,
)
acl.add_rule(
action=ACLAction.DENY,
protocol=IPProtocol.TCP,
src_ip=IPv4Address("192.168.1.3"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.4"),
dst_port=Port(80),
position=2,
)
assert not acl.is_permitted(
protocol=IPProtocol.UDP,
src_ip=IPv4Address("192.168.1.5"),
src_port=Port(8080),
dst_ip=IPv4Address("192.168.1.12"),
dst_port=Port(80),
)

View File

@@ -32,10 +32,8 @@ def test_nic_ip_address_type_conversion():
nic = NIC(
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
gateway="192.168.0.1",
)
assert isinstance(nic.ip_address, IPv4Address)
assert isinstance(nic.gateway, IPv4Address)
def test_nic_deserialize():
@@ -43,7 +41,6 @@ def test_nic_deserialize():
nic = NIC(
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
gateway="192.168.0.1",
)
nic_json = nic.model_dump_json()
@@ -51,21 +48,10 @@ def test_nic_deserialize():
assert nic == deserialized_nic
def test_nic_ip_address_as_gateway_fails():
"""Tests NIC creation fails if ip address is the same as the gateway."""
with pytest.raises(ValueError):
NIC(
ip_address="192.168.0.1",
subnet_mask="255.255.255.0",
gateway="192.168.0.1",
)
def test_nic_ip_address_as_network_address_fails():
"""Tests NIC creation fails if ip address and subnet mask are a network address."""
with pytest.raises(ValueError):
NIC(
ip_address="192.168.0.0",
subnet_mask="255.255.255.0",
gateway="192.168.0.1",
)