Files
PrimAITE/src/primaite/simulator/network/hardware/nodes/router.py
2023-12-22 10:31:11 +00:00

953 lines
38 KiB
Python

from __future__ import annotations
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
from primaite.simulator.system.core.sys_log import SysLog
class ACLAction(Enum):
"""Enum for defining the ACL action types."""
DENY = 0
PERMIT = 1
class ACLRule(SimComponent):
"""
Represents an Access Control List (ACL) rule.
:ivar ACLAction action: Action to be performed (Permit/Deny). Default is DENY.
:ivar Optional[IPProtocol] protocol: Network protocol. Default is None.
:ivar Optional[IPv4Address] src_ip_address: Source IP address. Default is None.
:ivar Optional[Port] src_port: Source port number. Default is None.
:ivar Optional[IPv4Address] dst_ip_address: Destination IP address. Default is None.
:ivar Optional[Port] dst_port: Destination port number. Default is None.
"""
action: ACLAction = ACLAction.DENY
protocol: Optional[IPProtocol] = None
src_ip_address: Optional[IPv4Address] = None
src_port: Optional[Port] = None
dst_ip_address: Optional[IPv4Address] = None
dst_port: Optional[Port] = None
def __str__(self) -> str:
rule_strings = []
for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items():
if value is None:
value = "ANY"
if isinstance(value, Enum):
rule_strings.append(f"{key}={value.name}")
else:
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
return state
class AccessControlList(SimComponent):
"""
Manages a list of ACLRules to filter network traffic.
:ivar SysLog sys_log: System logging instance.
:ivar ACLAction implicit_action: Default action for rules.
:ivar ACLRule implicit_rule: Implicit ACL rule, created based on implicit_action.
:ivar int max_acl_rules: Maximum number of ACL rules that can be added. Default is 25.
:ivar List[Optional[ACLRule]] _acl: A list containing the ACL rules.
"""
sys_log: SysLog
implicit_action: ACLAction
implicit_rule: ACLRule
max_acl_rules: int = 25
_acl: List[Optional[ACLRule]] = [None] * 24
def __init__(self, **kwargs) -> None:
if not kwargs.get("implicit_action"):
kwargs["implicit_action"] = ACLAction.DENY
kwargs["implicit_rule"] = ACLRule(action=kwargs["implicit_action"])
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.implicit_rule.set_original_state()
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
# POSITIONAL ARGUMENTS:
# 0: action (str name of an ACLAction)
# 1: protocol (str name of an IPProtocol)
# 2: source ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 3: source port (str name of a Port (e.g. "HTTP")) # should we be using value, such as 80 or 443?
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 5: destination port (str name of a Port (e.g. "HTTP"))
# 6: position (int)
rm.add_request(
"add_rule",
RequestType(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
None if request[1] == "ALL" else IPProtocol[request[1]],
IPv4Address(request[2]),
None if request[3] == "ALL" else Port[request[3]],
IPv4Address(request[4]),
None if request[5] == "ALL" else Port[request[5]],
int(request[6]),
)
),
)
rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
return rm
def describe_state(self) -> Dict:
"""
Describes the current state of the AccessControlList.
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["implicit_action"] = self.implicit_action.value
state["implicit_rule"] = self.implicit_rule.describe_state()
state["max_acl_rules"] = self.max_acl_rules
state["acl"] = {i: r.describe_state() if isinstance(r, ACLRule) else None for i, r in enumerate(self._acl)}
return state
@property
def acl(self) -> List[Optional[ACLRule]]:
"""
Get the list of ACL rules.
:return: The list of ACL rules.
"""
return self._acl
def add_rule(
self,
action: ACLAction,
protocol: Optional[IPProtocol] = None,
src_ip_address: Optional[Union[str, IPv4Address]] = None,
src_port: Optional[Port] = None,
dst_ip_address: Optional[Union[str, IPv4Address]] = None,
dst_port: Optional[Port] = None,
position: int = 0,
) -> None:
"""
Add a new ACL rule.
:param ACLAction action: Action to be performed (Permit/Deny).
:param Optional[IPProtocol] protocol: Network protocol.
:param Optional[Union[str, IPv4Address]] src_ip_address: Source IP address.
:param Optional[Port] src_port: Source port number.
:param Optional[Union[str, IPv4Address]] dst_ip_address: Destination IP address.
:param Optional[Port] dst_port: Destination port number.
:param int position: Position in the ACL list to insert the rule.
:raises ValueError: When the position is out of bounds.
"""
if isinstance(src_ip_address, str):
src_ip_address = IPv4Address(src_ip_address)
if isinstance(dst_ip_address, str):
dst_ip_address = IPv4Address(dst_ip_address)
if 0 <= position < self.max_acl_rules:
if self._acl[position]:
self.sys_log.info(f"Overwriting ACL rule at position {position}")
self._acl[position] = ACLRule(
action=action,
src_ip_address=src_ip_address,
dst_ip_address=dst_ip_address,
protocol=protocol,
src_port=src_port,
dst_port=dst_port,
)
else:
raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.")
def remove_rule(self, position: int) -> None:
"""
Remove an ACL rule from a specific position.
:param int position: The position of the rule to be removed.
:raises ValueError: When the position is out of bounds.
"""
if 0 <= position < self.max_acl_rules - 1:
rule = self._acl[position] # noqa
self._acl[position] = None
del rule
else:
raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.")
def is_permitted(
self,
protocol: IPProtocol,
src_ip_address: Union[str, IPv4Address],
src_port: Optional[Port],
dst_ip_address: Union[str, IPv4Address],
dst_port: Optional[Port],
) -> Tuple[bool, Optional[Union[str, ACLRule]]]:
"""
Check if a packet with the given properties is permitted through the ACL.
:param protocol: The protocol of the packet.
:param src_ip_address: Source IP address of the packet. Accepts string and IPv4Address.
:param src_port: Source port of the packet. Optional.
:param dst_ip_address: Destination IP address of the packet. Accepts string and IPv4Address.
:param dst_port: Destination port of the packet. Optional.
:return: A tuple with a boolean indicating if the packet is permitted and an optional rule or implicit action
string.
"""
if not isinstance(src_ip_address, IPv4Address):
src_ip_address = IPv4Address(src_ip_address)
if not isinstance(dst_ip_address, IPv4Address):
dst_ip_address = IPv4Address(dst_ip_address)
for rule in self._acl:
if not rule:
continue
if (
(rule.src_ip_address == src_ip_address or rule.src_ip_address is None)
and (rule.dst_ip_address == dst_ip_address or rule.dst_ip_address is None)
and (rule.protocol == protocol or rule.protocol is None)
and (rule.src_port == src_port or rule.src_port is None)
and (rule.dst_port == dst_port or rule.dst_port is None)
):
return rule.action == ACLAction.PERMIT, rule
return self.implicit_action == ACLAction.PERMIT, f"Implicit {self.implicit_action.name}"
def get_relevant_rules(
self,
protocol: IPProtocol,
src_ip_address: Union[str, IPv4Address],
src_port: Port,
dst_ip_address: Union[str, IPv4Address],
dst_port: Port,
) -> List[ACLRule]:
"""
Get the list of relevant rules for a packet with given properties.
:param protocol: The protocol of the packet.
:param src_ip_address: Source IP address of the packet. Accepts string and IPv4Address.
:param src_port: Source port of the packet.
:param dst_ip_address: Destination IP address of the packet. Accepts string and IPv4Address.
:param dst_port: Destination port of the packet.
:return: A list of relevant ACLRules.
"""
if not isinstance(src_ip_address, IPv4Address):
src_ip_address = IPv4Address(src_ip_address)
if not isinstance(dst_ip_address, IPv4Address):
dst_ip_address = IPv4Address(dst_ip_address)
relevant_rules = []
for rule in self._acl:
if rule is None:
continue
if (
(rule.src_ip_address == src_ip_address or rule.src_ip_address is None)
or (rule.dst_ip_address == dst_ip_address or rule.dst_ip_address is None)
or (rule.protocol == protocol or rule.protocol is None)
or (rule.src_port == src_port or rule.src_port is None)
or (rule.dst_port == dst_port or rule.dst_port is None)
):
relevant_rules.append(rule)
return relevant_rules
def show(self, markdown: bool = False):
"""
Display the current ACL rules as a table.
:param markdown: Whether to display the table in Markdown format. Defaults to False.
"""
table = PrettyTable(["Index", "Action", "Protocol", "Src IP", "Src Port", "Dst IP", "Dst Port"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} Access Control List"
for index, rule in enumerate(self.acl + [self.implicit_rule]):
if rule:
table.add_row(
[
index,
rule.action.name if rule.action else "ANY",
rule.protocol.name if rule.protocol else "ANY",
rule.src_ip_address if rule.src_ip_address else "ANY",
f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY",
rule.dst_ip_address if rule.dst_ip_address else "ANY",
f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY",
]
)
print(table)
class RouteEntry(SimComponent):
"""
Represents a single entry in a routing table.
:ivar address: The destination IP address or network address.
:ivar subnet_mask: The subnet mask for the network.
:ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded.
:ivar metric: The cost metric for this route. Default is 0.0.
Example:
>>> entry = RouteEntry(
... IPv4Address("192.168.1.0"),
... IPv4Address("255.255.255.0"),
... IPv4Address("192.168.2.1"),
... metric=5
... )
"""
address: IPv4Address
"The destination IP address or network address."
subnet_mask: IPv4Address
"The subnet mask for the network."
next_hop_ip_address: IPv4Address
"The next hop IP address to which packets should be forwarded."
metric: float = 0.0
"The cost metric for this route. Default is 0.0."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
self._original_values = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteEntry.
:return: A dictionary representing the current state.
"""
pass
class RouteTable(SimComponent):
"""
Represents a routing table holding multiple route entries.
:ivar List[RouteEntry] routes: A list of RouteEntry objects.
Example:
>>> rt = RouteTable()
>>> rt.add_route(
... RouteEntry(
... IPv4Address("192.168.1.0"),
... IPv4Address("255.255.255.0"),
... IPv4Address("192.168.2.1"),
... metric=5
... )
... )
>>> best_route = rt.find_best_route(IPv4Address("192.168.1.34"))
"""
routes: List[RouteEntry] = []
default_route: Optional[RouteEntry] = None
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.routes.clear()
self.routes = self._original_state["routes_orig"]
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteTable.
:return: A dictionary representing the current state.
"""
pass
def add_route(
self,
address: Union[IPv4Address, str],
subnet_mask: Union[IPv4Address, str],
next_hop_ip_address: Union[IPv4Address, str],
metric: float = 0.0,
):
"""
Add a route to the routing table.
:param address: The destination address of the route.
:param subnet_mask: The subnet mask of the route.
:param next_hop_ip_address: The next hop IP for the route.
:param metric: The metric of the route, default is 0.0.
"""
for key in {address, subnet_mask, next_hop_ip_address}:
if not isinstance(key, IPv4Address):
key = IPv4Address(key)
route = RouteEntry(
address=address, subnet_mask=subnet_mask, next_hop_ip_address=next_hop_ip_address, metric=metric
)
self.routes.append(route)
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
"""
Sets the next-hop IP address for the default route in a routing table.
This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist,
the method creates a new default route with the specified next-hop IP address. If a default route already
exists, it updates the next-hop IP address of the existing default route. After setting the next-hop
IP address, the method logs this action.
:param ip_address: The next-hop IP address to be set for the default route.
"""
if not self.default_route:
self.default_route = RouteEntry(
ip_address=IPv4Address("0.0.0.0"),
subnet_mask=IPv4Address("0.0.0.0"),
next_hop_ip_address=ip_address,
)
else:
self.default_route.next_hop_ip_address = ip_address
self.sys_log.info(f"Default configured to use {ip_address} as the next-hop")
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
"""
Find the best route for a given destination IP.
This method uses the Longest Prefix Match algorithm and considers metrics to find the best route.
If no dedicated route exists but a default route does, then the default route is returned as a last resort.
:param destination_ip: The destination IP to find the route for.
:return: The best matching RouteEntry, or None if no route matches.
"""
if not isinstance(destination_ip, IPv4Address):
destination_ip = IPv4Address(destination_ip)
best_route = None
longest_prefix = -1
lowest_metric = float("inf") # Initialise at infinity as any other number we compare to it will be smaller
for route in self.routes:
route_network = IPv4Network(f"{route.address}/{route.subnet_mask}", strict=False)
prefix_len = route_network.prefixlen
if destination_ip in route_network:
if prefix_len > longest_prefix or (prefix_len == longest_prefix and route.metric < lowest_metric):
best_route = route
longest_prefix = prefix_len
lowest_metric = route.metric
if not best_route and self.default_route:
best_route = self.default_route
return best_route
def show(self, markdown: bool = False):
"""
Display the current routing table as a table.
:param markdown: Whether to display the table in Markdown format. Defaults to False.
"""
table = PrettyTable(["Index", "Address", "Next Hop", "Metric"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} Route Table"
for index, route in enumerate(self.routes):
network = IPv4Network(f"{route.address}/{route.subnet_mask}")
table.add_row([index, f"{route.address}/{network.prefixlen}", route.next_hop_ip_address, route.metric])
print(table)
class RouterARPCache(ARPCache):
"""
Inherits from ARPCache and adds router-specific ARP packet processing.
:ivar SysLog sys_log: A system log for logging messages.
:ivar Router router: The router to which this ARP cache belongs.
"""
def __init__(self, sys_log: SysLog, router: Router):
super().__init__(sys_log)
self.router: Router = router
def process_arp_packet(
self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False
) -> None:
"""
Processes a received ARP (Address Resolution Protocol) packet in a router-specific way.
This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a
Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This
includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and
handling packet TTL (Time To Live).
The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache
and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's
NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table
to find the best route and reattempts ARP request processing if needed.
:param from_nic: The NIC that received the ARP packet.
:param frame: The frame containing the ARP packet.
:param route_table: The routing table of the router.
:param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False.
"""
arp_packet = frame.arp
# ARP Reply
if not arp_packet.request:
for nic in self.router.nics.values():
if arp_packet.target_ip_address == nic.ip_address:
# reply to the Router specifically
self.sys_log.info(
f"Received ARP response for {arp_packet.sender_ip_address} "
f"from {arp_packet.sender_mac_addr} via NIC {from_nic}"
)
self.add_arp_cache_entry(
ip_address=arp_packet.sender_ip_address,
mac_address=arp_packet.sender_mac_addr,
nic=from_nic,
)
return
# Reply for a connected requested
nic = self.get_arp_cache_nic(arp_packet.target_ip_address)
if nic:
self.sys_log.info(
f"Forwarding arp reply for {arp_packet.target_ip_address}, from {arp_packet.sender_ip_address}"
)
arp_packet.sender_mac_addr = nic.mac_address
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self.sys_log.info("Frame discarded as TTL limit reached")
return
nic.send_frame(frame)
return
# ARP Request
self.sys_log.info(
f"Received ARP request for {arp_packet.target_ip_address} from "
f"{arp_packet.sender_mac_addr}/{arp_packet.sender_ip_address} "
)
# Matched ARP request
self.add_arp_cache_entry(
ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic
)
# If the target IP matches one of the router's NICs
for nic in self.nics.values():
if arp_packet.target_ip_address in nic.ip_network:
# if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_reply, from_nic)
return
# Check Route Table
route = route_table.find_best_route(arp_packet.target_ip_address)
if route:
nic = self.get_arp_cache_nic(route.next_hop_ip_address)
if not nic:
if not is_reattempt:
self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address])
return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True)
else:
pass
# TODO: destination unavailable/No ARP netry found
else:
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
self.send_arp_reply(arp_reply, from_nic)
return
class RouterICMP(ICMP):
"""
A class to represent a router's Internet Control Message Protocol (ICMP) handler.
:param sys_log: System log for logging network events and errors.
:type sys_log: SysLog
:param arp_cache: The ARP cache for resolving MAC addresses.
:type arp_cache: ARPCache
:param router: The router to which this ICMP handler belongs.
:type router: Router
"""
router: Router
def __init__(self, sys_log: SysLog, arp_cache: ARPCache, router: Router):
super().__init__(sys_log, arp_cache)
self.router = router
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
"""
Process incoming ICMP frames based on ICMP type.
:param frame: The incoming frame to process.
:param from_nic: The network interface where the frame is coming from.
:param is_reattempt: Flag to indicate if the process is a reattempt.
"""
if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST:
# determine if request is for router interface or whether it needs to be routed
for nic in self.router.nics.values():
if nic.ip_address == frame.ip.dst_ip_address:
if nic.enabled:
# reply to the request
if not is_reattempt:
self.sys_log.info(f"Received echo request from {frame.ip.src_ip_address}")
target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip_address)
src_nic = self.arp.get_arp_cache_nic(frame.ip.src_ip_address)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(
src_ip_address=nic.ip_address,
dst_ip_address=frame.ip.src_ip_address,
protocol=IPProtocol.ICMP,
)
# Data Link Layer
ethernet_header = EthernetHeader(
src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address
)
icmp_reply_packet = ICMPPacket(
icmp_type=ICMPType.ECHO_REPLY,
icmp_code=0,
identifier=frame.icmp.identifier,
sequence=frame.icmp.sequence + 1,
)
payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size
frame = Frame(
ethernet=ethernet_header,
ip=ip_packet,
tcp=tcp_header,
icmp=icmp_reply_packet,
payload=payload,
)
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}")
src_nic.send_frame(frame)
return
# Route the frame
self.router.process_frame(frame, from_nic)
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
for nic in self.router.nics.values():
if nic.ip_address == frame.ip.dst_ip_address:
if nic.enabled:
time = frame.transmission_duration()
time_str = f"{time}ms" if time > 0 else "<1ms"
self.sys_log.info(
f"Reply from {frame.ip.src_ip_address}: "
f"bytes={len(frame.payload)}, "
f"time={time_str}, "
f"TTL={frame.ip.ttl}"
)
if not self.request_replies.get(frame.icmp.identifier):
self.request_replies[frame.icmp.identifier] = 0
self.request_replies[frame.icmp.identifier] += 1
return
# Route the frame
self.router.process_frame(frame, from_nic)
class Router(Node):
"""
A class to represent a network router node.
:ivar str hostname: The name of the router node.
:ivar int num_ports: The number of ports in the router.
:ivar dict kwargs: Optional keyword arguments for SysLog, ACL, RouteTable, RouterARPCache, RouterICMP.
"""
num_ports: int
ethernet_ports: Dict[int, NIC] = {}
acl: AccessControlList
route_table: RouteTable
arp: RouterARPCache
icmp: RouterICMP
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
if not kwargs.get("arp"):
kwargs["arp"] = RouterARPCache(sys_log=kwargs.get("sys_log"), router=self)
if not kwargs.get("icmp"):
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
for i in range(1, self.num_ports + 1):
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(nic)
self.ethernet_ports[i] = nic
self.arp.nics = self.nics
self.icmp.arp = self.arp
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.acl.set_original_state()
self.route_table.set_original_state()
super().set_original_state()
vals_to_include = {"num_ports"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.arp.clear()
self.acl.reset_component_for_episode(episode)
self.route_table.reset_component_for_episode(episode)
for i, nic in self.ethernet_ports.items():
nic.reset_component_for_episode(episode)
self.enable_port(i)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))
return rm
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
"""
Retrieve the port number for a given NIC.
:param target_nic: Target network interface.
:return: The port number if NIC is found, otherwise None.
"""
for port, nic in self.ethernet_ports.items():
if nic == target_nic:
return port
def describe_state(self) -> Dict:
"""
Describes the current state of the Router.
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = (self.num_ports,)
state["acl"] = (self.acl.describe_state(),)
return state
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
"""
Process a Frame.
:param frame: The frame to be routed.
:param from_nic: The source network interface.
:param re_attempt: Flag to indicate if the routing is a reattempt.
"""
# Check if src ip is on network of one of the NICs
nic = self.arp.get_arp_cache_nic(frame.ip.dst_ip_address)
target_mac = self.arp.get_arp_cache_mac_address(frame.ip.dst_ip_address)
if re_attempt and not nic:
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
return
if not nic:
self.arp.send_arp_request(
frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address]
)
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
if not nic.enabled:
# TODO: Add sys_log here
return
if frame.ip.dst_ip_address in nic.ip_network:
from_port = self._get_port_of_nic(from_nic)
to_port = self._get_port_of_nic(nic)
self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}")
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self.sys_log.info("Frame discarded as TTL limit reached")
return
frame.ethernet.src_mac_addr = nic.mac_address
frame.ethernet.dst_mac_addr = target_mac
nic.send_frame(frame)
return
else:
self._route_frame(frame, from_nic)
def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
route = self.route_table.find_best_route(frame.ip.dst_ip_address)
if route:
nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address)
target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
if re_attempt and not nic:
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
return
if not nic:
self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address])
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
if not nic.enabled:
# TODO: Add sys_log here
return
from_port = self._get_port_of_nic(from_nic)
to_port = self._get_port_of_nic(nic)
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
frame.decrement_ttl()
if frame.ip and frame.ip.ttl < 1:
self.sys_log.info("Frame discarded as TTL limit reached")
return
frame.ethernet.src_mac_addr = nic.mac_address
frame.ethernet.dst_mac_addr = target_mac
nic.send_frame(frame)
def receive_frame(self, frame: Frame, from_nic: NIC):
"""
Receive a frame from a NIC and processes it based on its protocol.
:param frame: The incoming frame.
:param from_nic: The network interface where the frame is coming from.
"""
process_frame = False
protocol = frame.ip.protocol
src_ip_address = frame.ip.src_ip_address
dst_ip_address = frame.ip.dst_ip_address
src_port = None
dst_port = None
if frame.ip.protocol == IPProtocol.TCP:
src_port = frame.tcp.src_port
dst_port = frame.tcp.dst_port
elif frame.ip.protocol == IPProtocol.UDP:
src_port = frame.udp.src_port
dst_port = frame.udp.dst_port
# Check if it's permitted
permitted, rule = self.acl.is_permitted(
protocol=protocol,
src_ip_address=src_ip_address,
src_port=src_port,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
)
if not permitted:
at_port = self._get_port_of_nic(from_nic)
self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}")
return
if not self.arp.get_arp_cache_nic(src_ip_address):
self.arp.add_arp_cache_entry(src_ip_address, frame.ethernet.src_mac_addr, from_nic)
if frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
else:
if src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table)
else:
# All other traffic
process_frame = True
if process_frame:
self.process_frame(frame, from_nic)
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
"""
Configure the IP settings of a given port.
:param port: The port to configure.
:param ip_address: The IP address to set.
:param subnet_mask: The subnet mask to set.
"""
if not isinstance(ip_address, IPv4Address):
ip_address = IPv4Address(ip_address)
if not isinstance(subnet_mask, IPv4Address):
subnet_mask = IPv4Address(subnet_mask)
nic = self.ethernet_ports[port]
nic.ip_address = ip_address
nic.subnet_mask = subnet_mask
self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}")
self.set_original_state()
def enable_port(self, port: int):
"""
Enable a given port on the router.
:param port: The port to enable.
"""
nic = self.ethernet_ports.get(port)
if nic:
nic.enable()
def disable_port(self, port: int):
"""
Disable a given port on the router.
:param port: The port to disable.
"""
nic = self.ethernet_ports.get(port)
if nic:
nic.disable()
def show(self, markdown: bool = False):
"""
Prints the state of the Ethernet interfaces on the Router.
:param markdown: Flag to indicate if the output should be in markdown format.
"""
"""Prints a table of the NICs on the Node."""
table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Ethernet Interfaces"
for port, nic in self.ethernet_ports.items():
table.add_row(
[
port,
nic.mac_address,
f"{nic.ip_address}/{nic.ip_network.prefixlen}",
nic.speed,
"Enabled" if nic.enabled else "Disabled",
]
)
print(table)