#1706 - Refactored a bunch of if statements in base.py to improve readability

This commit is contained in:
Chris McCarthy
2023-08-09 20:31:42 +01:00
parent a840159460
commit b46057841d

View File

@@ -135,29 +135,33 @@ class NIC(SimComponent):
def enable(self):
"""Attempt to enable the NIC."""
if not self.enabled:
if self.connected_node:
if self.connected_node.operating_state == NodeOperatingState.ON:
self.enabled = True
self.connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, ip_address=self.ip_address)
if self.connected_link:
self.connected_link.endpoint_up()
else:
self.connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
else:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
if self.enabled:
return
if not self.connected_node:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
return
self.enabled = True
self.connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, ip_address=self.ip_address)
if self.connected_link:
self.connected_link.endpoint_up()
def disable(self):
"""Disable the NIC."""
if self.enabled:
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.info(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if not self.enabled:
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.info(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -166,15 +170,17 @@ class NIC(SimComponent):
:param link: The link to which the NIC is connected.
:type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link`
"""
if not self.connected_link:
if self.connected_link != link:
# TODO: Inform the Node that a link has been connected
self.connected_link = link
_LOGGER.info(f"NIC {self} connected to Link {link}")
else:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it is already connected")
else:
if self.connected_link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it already has a connection")
return
if self.connected_link == link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it is already connected")
return
# TODO: Inform the Node that a link has been connected
self.connected_link = link
_LOGGER.info(f"NIC {self} connected to Link {link}")
def disconnect_link(self):
"""Disconnect the NIC from the connected Link."""
@@ -214,9 +220,8 @@ class NIC(SimComponent):
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
else:
# Cannot send Frame as the NIC is not enabled
return False
# Cannot send Frame as the NIC is not enabled
return False
def receive_frame(self, frame: Frame) -> bool:
"""
@@ -233,8 +238,7 @@ class NIC(SimComponent):
self.pcap.capture(frame)
self.connected_node.receive_frame(frame=frame, from_nic=self)
return True
else:
return False
return False
def describe_state(self) -> Dict:
"""
@@ -290,33 +294,34 @@ class SwitchPort(SimComponent):
def enable(self):
"""Attempt to enable the SwitchPort."""
if not self.enabled:
if self.connected_node:
if self.connected_node.operating_state == NodeOperatingState.ON:
self.enabled = True
self.connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname)
if self.connected_link:
self.connected_link.endpoint_up()
else:
self.connected_node.sys_log.info(
f"SwitchPort {self} cannot be enabled as the endpoint is not turned on"
)
else:
msg = f"SwitchPort {self} cannot be enabled as it is not connected to a Node"
_LOGGER.error(msg)
raise NetworkError(msg)
if self.enabled:
return
if not self.connected_node:
_LOGGER.error(f"SwitchPort {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
return
self.enabled = True
self.connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname)
if self.connected_link:
self.connected_link.endpoint_up()
def disable(self):
"""Disable the SwitchPort."""
if self.enabled:
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.info(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if not self.enabled:
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.info(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -324,16 +329,18 @@ class SwitchPort(SimComponent):
:param link: The link to which the SwitchPort is connected.
"""
if not self.connected_link:
if self.connected_link != link:
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
_LOGGER.info(f"SwitchPort {self} connected to Link {link}")
self.enable()
else:
_LOGGER.error(f"Cannot connect Link to SwitchPort {self.mac_address} as it is already connected")
else:
if self.connected_link:
_LOGGER.error(f"Cannot connect link to SwitchPort {self.mac_address} as it already has a connection")
return
if self.connected_link == link:
_LOGGER.error(f"Cannot connect Link to SwitchPort {self.mac_address} as it is already connected")
return
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
_LOGGER.info(f"SwitchPort {self} connected to Link {link}")
self.enable()
def disconnect_link(self):
"""Disconnect the SwitchPort from the connected Link."""
@@ -353,9 +360,8 @@ class SwitchPort(SimComponent):
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
else:
# Cannot send Frame as the SwitchPort is not enabled
return False
# Cannot send Frame as the SwitchPort is not enabled
return False
def receive_frame(self, frame: Frame) -> bool:
"""
@@ -370,8 +376,7 @@ class SwitchPort(SimComponent):
self.pcap.capture(frame)
self.connected_node.forward_frame(frame=frame, incoming_port=self)
return True
else:
return False
return False
def describe_state(self) -> Dict:
"""
@@ -468,30 +473,27 @@ class Link(SimComponent):
:param frame: The network frame to be sent.
:return: True if the Frame can be sent, otherwise False.
"""
if self._can_transmit(frame):
receiver = self.endpoint_a
if receiver == sender_nic:
receiver = self.endpoint_b
frame_size = frame.size_Mbits
sent = receiver.receive_frame(frame)
if sent:
# Frame transmitted successfully
# Load the frame size on the link
self.current_load += frame_size
(
_LOGGER.info(
f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits "
f"({self.current_load_percent})"
)
)
return True
# Received NIC disabled, reply
return False
else:
can_transmit = self._can_transmit(frame)
if not can_transmit:
_LOGGER.info(f"Cannot transmit frame as {self} is at capacity")
return False
receiver = self.endpoint_a
if receiver == sender_nic:
receiver = self.endpoint_b
frame_size = frame.size_Mbits
if receiver.receive_frame(frame):
# Frame transmitted successfully
# Load the frame size on the link
self.current_load += frame_size
_LOGGER.info(
f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits "
f"({self.current_load_percent})"
)
return True
return False
def reset_component_for_episode(self, episode: int):
"""
Link reset function.
@@ -624,43 +626,48 @@ class ARPCache:
:param from_nic: The NIC that received the ARP packet.
:param arp_packet: The ARP packet to be processed.
"""
if arp_packet.request:
self.sys_log.info(
f"Received ARP request for {arp_packet.target_ip} from "
f"{arp_packet.sender_mac_addr}/{arp_packet.sender_ip} "
)
if arp_packet.target_ip == from_nic.ip_address:
self._add_arp_cache_entry(
ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic
)
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
self.sys_log.info(
f"Sending ARP reply from {arp_packet.sender_mac_addr}/{arp_packet.sender_ip} "
f"to {arp_packet.target_ip}/{arp_packet.target_mac_addr} "
)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(
src_ip=arp_packet.sender_ip,
dst_ip=arp_packet.target_ip,
)
# Data Link Layer
ethernet_header = EthernetHeader(
src_mac_addr=arp_packet.sender_mac_addr, dst_mac_addr=arp_packet.target_mac_addr
)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet)
from_nic.send_frame(frame)
else:
self.sys_log.info(f"Ignoring ARP request for {arp_packet.target_ip}")
else:
# ARP Reply
if not arp_packet.request:
self.sys_log.info(
f"Received ARP response for {arp_packet.sender_ip} from {arp_packet.sender_mac_addr} via NIC {from_nic}"
)
self._add_arp_cache_entry(
ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic
)
return
# ARP Request
self.sys_log.info(
f"Received ARP request for {arp_packet.target_ip} from "
f"{arp_packet.sender_mac_addr}/{arp_packet.sender_ip} "
)
# Unmatched ARP Request
if arp_packet.target_ip != from_nic.ip_address:
self.sys_log.info(f"Ignoring ARP request for {arp_packet.target_ip}")
return
# Matched ARP request
self._add_arp_cache_entry(ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic)
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
self.sys_log.info(
f"Sending ARP reply from {arp_packet.sender_mac_addr}/{arp_packet.sender_ip} "
f"to {arp_packet.target_ip}/{arp_packet.target_mac_addr} "
)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(
src_ip=arp_packet.sender_ip,
dst_ip=arp_packet.target_ip,
)
# Data Link Layer
ethernet_header = EthernetHeader(
src_mac_addr=arp_packet.sender_mac_addr, dst_mac_addr=arp_packet.target_mac_addr
)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet)
from_nic.send_frame(frame)
class ICMP:
@@ -721,30 +728,34 @@ class ICMP:
was not found in the ARP cache.
"""
nic = self.arp.get_arp_cache_nic(target_ip_address)
if nic:
sequence += 1
target_mac_address = self.arp.get_arp_cache_mac_address(target_ip_address)
src_nic = self.arp.get_arp_cache_nic(target_ip_address)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(
src_ip=nic.ip_address,
dst_ip=target_ip_address,
protocol=IPProtocol.ICMP,
)
# Data Link Layer
ethernet_header = EthernetHeader(src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address)
icmp_packet = ICMPPacket(identifier=identifier, sequence=sequence)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_packet)
self.sys_log.info(f"Sending echo request to {target_ip_address}")
nic.send_frame(frame)
return sequence, icmp_packet.identifier
else:
# TODO: Eventually this ARP request needs to be done elsewhere. It's not the resonsibility of the
# ping function to handle ARP lookups
# No existing ARP entry
if not nic:
self.sys_log.info(f"No entry in ARP cache for {target_ip_address}")
self.arp.send_arp_request(target_ip_address)
return 0, None
# ARP entry exists
sequence += 1
target_mac_address = self.arp.get_arp_cache_mac_address(target_ip_address)
src_nic = self.arp.get_arp_cache_nic(target_ip_address)
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
# Network Layer
ip_packet = IPPacket(
src_ip=nic.ip_address,
dst_ip=target_ip_address,
protocol=IPProtocol.ICMP,
)
# Data Link Layer
ethernet_header = EthernetHeader(src_mac_addr=src_nic.mac_address, dst_mac_addr=target_mac_address)
icmp_packet = ICMPPacket(identifier=identifier, sequence=sequence)
frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_packet)
self.sys_log.info(f"Sending echo request to {target_ip_address}")
nic.send_frame(frame)
return sequence, icmp_packet.identifier
class NodeOperatingState(Enum):
"""Enumeration of Node Operating States."""