diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index ad157c9c..9a7ba596 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -16,6 +16,8 @@ from platformdirs import PlatformDirs with open(Path(__file__).parent.resolve() / "VERSION", "r") as file: __version__ = file.readline().strip() +_PRIMAITE_ROOT: Path = Path(__file__).parent + class _PrimaitePaths: """ diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index e69de29b..5b65ad40 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -0,0 +1,4 @@ +from primaite import _PRIMAITE_ROOT + +TEMP_SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" +"A path at the repo root dir to use temporarily for sim output testing while in dev." diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 138c444c..739fb933 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -13,8 +13,10 @@ from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader -from primaite.simulator.system.packet_capture import PacketCapture -from primaite.simulator.system.sys_log import SysLog +from primaite.simulator.system.core.packet_capture import PacketCapture +from primaite.simulator.system.core.session_manager import SessionManager +from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.simulator.system.core.sys_log import SysLog _LOGGER = getLogger(__name__) @@ -103,8 +105,8 @@ class NIC(SimComponent): kwargs["ip_address"] = IPv4Address(kwargs["ip_address"]) if not isinstance(kwargs["gateway"], IPv4Address): kwargs["gateway"] = IPv4Address(kwargs["gateway"]) - if "mac_address" not in kwargs: - kwargs["mac_address"] = generate_mac_address() + if "mac_address" not in kwargs: + kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) if self.ip_address == self.gateway: @@ -163,9 +165,9 @@ class NIC(SimComponent): """ if not self.connected_link: if self.connected_link != link: - _LOGGER.info(f"NIC {self} connected to Link") # TODO: Inform the Node that a link has been connected self.connected_link = link + _LOGGER.info(f"NIC {self} connected to Link {link}") else: _LOGGER.warning(f"Cannot connect link to NIC ({self.mac_address}) as it is already connected") else: @@ -254,23 +256,155 @@ class NIC(SimComponent): return f"{self.mac_address}/{self.ip_address}" +class SwitchPort(SimComponent): + """ + Models a switch port in a network switch device. + + :param mac_address: The MAC address of the SwitchPort. Defaults to a randomly set MAC address. + :param speed: The speed of the SwitchPort in Mbps (default is 100 Mbps). + :param mtu: The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes, representing the largest data packet + size it can handle without fragmentation (default is 1500 B). + """ + + port_num: int = 1 + mac_address: str + "The MAC address of the SwitchPort. Defaults to a randomly set MAC address." + speed: int = 100 + "The speed of the SwitchPort in Mbps. Default is 100 Mbps." + mtu: int = 1500 + "The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B" + connected_node: Optional[Switch] = None + "The Node to which the SwitchPort is connected." + connected_link: Optional[Link] = None + "The Link to which the SwitchPort is connected." + enabled: bool = False + "Indicates whether the SwitchPort is enabled." + pcap: Optional[PacketCapture] = None + + def __init__(self, **kwargs): + """The SwitchPort constructor.""" + if "mac_address" not in kwargs: + kwargs["mac_address"] = generate_mac_address() + super().__init__(**kwargs) + + def enable(self): + """Attempt to enable the SwitchPort.""" + if not self.enabled: + if self.connected_node: + if self.connected_node.operating_state == NodeOperatingState.ON: + self.enabled = True + _LOGGER.info(f"SwitchPort {self} enabled") + self.pcap = PacketCapture(hostname=self.connected_node.hostname) + if self.connected_link: + self.connected_link.endpoint_up() + else: + _LOGGER.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on") + else: + msg = f"SwitchPort {self} cannot be enabled as it is not connected to a Node" + _LOGGER.error(msg) + raise NetworkError(msg) + + def disable(self): + """Disable the SwitchPort.""" + if self.enabled: + self.enabled = False + _LOGGER.info(f"SwitchPort {self} disabled") + if self.connected_link: + self.connected_link.endpoint_down() + + def connect_link(self, link: Link): + """ + Connect the SwitchPort to a link. + + :param link: The link to which the SwitchPort is connected. + :raise NetworkError: When an attempt to connect a Link is made while the SwitchPort has a connected Link. + """ + if not self.connected_link: + if self.connected_link != link: + # TODO: Inform the Switch that a link has been connected + self.connected_link = link + _LOGGER.info(f"SwitchPort {self} connected to Link {link}") + self.enable() + else: + _LOGGER.warning(f"Cannot connect link to SwitchPort ({self.mac_address}) as it is already connected") + else: + msg = f"Cannot connect link to SwitchPort ({self.mac_address}) as it already has a connection" + _LOGGER.error(msg) + raise NetworkError(msg) + + def disconnect_link(self): + """Disconnect the SwitchPort from the connected Link.""" + if self.connected_link.endpoint_a == self: + self.connected_link.endpoint_a = None + if self.connected_link.endpoint_b == self: + self.connected_link.endpoint_b = None + self.connected_link = None + + def send_frame(self, frame: Frame) -> bool: + """ + Send a network frame from the SwitchPort to the connected link. + + :param frame: The network frame to be sent. + """ + if self.enabled: + self.pcap.capture(frame) + self.connected_link.transmit_frame(sender_nic=self, frame=frame) + return True + else: + # Cannot send Frame as the SwitchPort is not enabled + return False + + def receive_frame(self, frame: Frame) -> bool: + """ + Receive a network frame from the connected link if the SwitchPort is enabled. + + The Frame is passed to the Node. + + :param frame: The network frame being received. + """ + if self.enabled: + frame.decrement_ttl() + self.pcap.capture(frame) + self.connected_node.forward_frame(frame=frame, incoming_port=self) + return True + else: + return False + + def describe_state(self) -> Dict: + """ + Get the current state of the SwitchPort as a dict. + + :return: A dict containing the current state of the SwitchPort. + """ + pass + + def apply_action(self, action: str): + """ + Apply an action to the SwitchPort. + + :param action: The action to be applied. + :type action: str + """ + pass + + def __str__(self) -> str: + return f"{self.mac_address}" + + class Link(SimComponent): """ - Represents a network link between two network interface cards (NICs). + Represents a network link between NIC<-->, NIC<-->SwitchPort, or SwitchPort<-->SwitchPort. - :param endpoint_a: The first NIC connected to the Link. - :type endpoint_a: NIC - :param endpoint_b: The second NIC connected to the Link. - :type endpoint_b: NIC + :param endpoint_a: The first NIC or SwitchPort connected to the Link. + :param endpoint_b: The second NIC or SwitchPort connected to the Link. :param bandwidth: The bandwidth of the Link in Mbps (default is 100 Mbps). - :type bandwidth: int """ - endpoint_a: NIC - "The first NIC connected to the Link." - endpoint_b: NIC - "The second NIC connected to the Link." - bandwidth: int = 100 + endpoint_a: Union[NIC, SwitchPort] + "The first NIC or SwitchPort connected to the Link." + endpoint_b: Union[NIC, SwitchPort] + "The second NIC or SwitchPort connected to the Link." + bandwidth: float = 100.0 "The bandwidth of the Link in Mbps (default is 100 Mbps)." current_load: float = 0.0 "The current load on the link in Mbps." @@ -284,7 +418,7 @@ class Link(SimComponent): :raises ValueError: If endpoint_a and endpoint_b are the same NIC. """ if kwargs["endpoint_a"] == kwargs["endpoint_b"]: - msg = "endpoint_a and endpoint_b cannot be the same NIC" + msg = "endpoint_a and endpoint_b cannot be the same NIC or SwitchPort" _LOGGER.error(msg) raise ValueError(msg) super().__init__(**kwargs) @@ -292,6 +426,11 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() + @property + def current_load_percent(self) -> str: + """Get the current load formatted as a percentage string.""" + return f"{self.current_load / self.bandwidth:.5f}%" + def endpoint_up(self): """Let the Link know and endpoint has been brought up.""" if self.up: @@ -318,25 +457,30 @@ class Link(SimComponent): return self.current_load + frame_size_Mbits <= self.bandwidth return False - def transmit_frame(self, sender_nic: NIC, frame: Frame) -> bool: + def transmit_frame(self, sender_nic: Union[NIC, SwitchPort], frame: Frame) -> bool: """ - Send a network frame from one NIC to another connected NIC. + Send a network frame from one NIC or SwitchPort to another connected NIC or SwitchPort. - :param sender_nic: The NIC sending the frame. + :param sender_nic: The NIC or SwitchPort sending the frame. :param frame: The network frame to be sent. :return: True if the Frame can be sent, otherwise False. """ if self._can_transmit(frame): - receiver_nic = self.endpoint_a - if receiver_nic == sender_nic: - receiver_nic = self.endpoint_b + receiver = self.endpoint_a + if receiver == sender_nic: + receiver = self.endpoint_b frame_size = frame.size_Mbits - sent = receiver_nic.receive_frame(frame) + sent = receiver.receive_frame(frame) if sent: # Frame transmitted successfully # Load the frame size on the link self.current_load += frame_size - _LOGGER.info(f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits") + ( + _LOGGER.info( + f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits " + f"({self.current_load_percent})" + ) + ) return True # Received NIC disabled, reply @@ -345,7 +489,7 @@ class Link(SimComponent): _LOGGER.info(f"Cannot transmit frame as {self} is at capacity") return False - def reset_component_for_episode(self): + def reset_component_for_episode(self, episode: int): """ Link reset function. @@ -356,7 +500,7 @@ class Link(SimComponent): def describe_state(self) -> Dict: """ - Get the current state of the Libk as a dict. + Get the current state of the Link as a dict. :return: A dict containing the current state of the Link. """ @@ -375,108 +519,23 @@ class Link(SimComponent): return f"{self.endpoint_a}<-->{self.endpoint_b}" -class NodeOperatingState(Enum): - """Enumeration of Node Operating States.""" - - OFF = 0 - "The node is powered off." - ON = 1 - "The node is powered on." - SHUTTING_DOWN = 2 - "The node is in the process of shutting down." - BOOTING = 3 - "The node is in the process of booting up." - - -class Node(SimComponent): +class ARPCache: """ - A basic Node class. + The ARPCache (Address Resolution Protocol) class. - :param hostname: The node hostname on the network. - :param operating_state: The node operating state. + Responsible for maintaining a mapping between IP addresses and MAC addresses (ARP cache) for the network. It + provides methods for looking up, adding, and removing entries, and for processing ARPPackets. """ - hostname: str - "The node hostname on the network." - operating_state: NodeOperatingState = NodeOperatingState.OFF - "The hardware state of the node." - nics: Dict[str, NIC] = {} - "The NICs on the node." - - accounts: Dict = {} - "All accounts on the node." - applications: Dict = {} - "All applications on the node." - services: Dict = {} - "All services on the node." - processes: Dict = {} - "All processes on the node." - file_system: Any = None - "The nodes file system." - arp_cache: Dict[IPv4Address, ARPEntry] = {} - "The ARP cache." - sys_log: Optional[SysLog] = None - - revealed_to_red: bool = False - "Informs whether the node has been revealed to a red agent." - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.sys_log = SysLog(self.hostname) - - def turn_on(self): - """Turn on the Node.""" - if self.operating_state == NodeOperatingState.OFF: - self.operating_state = NodeOperatingState.ON - self.sys_log.info("Turned on") - for nic in self.nics.values(): - nic.enable() - - def turn_off(self): - """Turn off the Node.""" - if self.operating_state == NodeOperatingState.ON: - for nic in self.nics.values(): - nic.disable() - self.operating_state = NodeOperatingState.OFF - self.sys_log.info("Turned off") - - def connect_nic(self, nic: NIC): + def __init__(self, sys_log: "SysLog"): """ - Connect a NIC. + Initialize an ARP (Address Resolution Protocol) cache. - :param nic: The NIC to connect. - :raise NetworkError: If the NIC is already connected. + :param sys_log: The nodes sys log. """ - if nic.uuid not in self.nics: - self.nics[nic.uuid] = nic - nic.connected_node = self - self.sys_log.info(f"Connected NIC {nic}") - if self.operating_state == NodeOperatingState.ON: - nic.enable() - else: - msg = f"Cannot connect NIC {nic} as it is already connected" - self.sys_log.logger.error(msg) - _LOGGER.error(msg) - raise NetworkError(msg) - - def disconnect_nic(self, nic: Union[NIC, str]): - """ - Disconnect a NIC. - - :param nic: The NIC to Disconnect. - :raise NetworkError: If the NIC is not connected. - """ - if isinstance(nic, str): - nic = self.nics.get(nic) - if nic or nic.uuid in self.nics: - self.nics.pop(nic.uuid) - nic.disable() - self.sys_log.info(f"Disconnected NIC {nic}") - else: - msg = f"Cannot disconnect NIC {nic} as it is not connected" - self.sys_log.logger.error(msg) - _LOGGER.error(msg) - raise NetworkError(msg) + self.sys_log: "SysLog" = sys_log + self.arp: Dict[IPv4Address, ARPEntry] = {} + self.nics: Dict[str, "NIC"] = {} def _add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC): """ @@ -488,7 +547,7 @@ class Node(SimComponent): """ self.sys_log.info(f"Adding ARP cache entry for {mac_address}/{ip_address} via NIC {nic}") arp_entry = ARPEntry(mac_address=mac_address, nic_uuid=nic.uuid) - self.arp_cache[ip_address] = arp_entry + self.arp[ip_address] = arp_entry def _remove_arp_cache_entry(self, ip_address: IPv4Address): """ @@ -496,37 +555,44 @@ class Node(SimComponent): :param ip_address: The IP address to be removed from the cache. """ - if ip_address in self.arp_cache: - del self.arp_cache[ip_address] + if ip_address in self.arp: + del self.arp[ip_address] - def _get_arp_cache_mac_address(self, ip_address: IPv4Address) -> Optional[str]: + def get_arp_cache_mac_address(self, ip_address: IPv4Address) -> Optional[str]: """ Get the MAC address associated with an IP address. :param ip_address: The IP address to look up in the cache. :return: The MAC address associated with the IP address, or None if not found. """ - arp_entry = self.arp_cache.get(ip_address) + arp_entry = self.arp.get(ip_address) if arp_entry: return arp_entry.mac_address - def _get_arp_cache_nic(self, ip_address: IPv4Address) -> Optional[NIC]: + def get_arp_cache_nic(self, ip_address: IPv4Address) -> Optional[NIC]: """ Get the NIC associated with an IP address. :param ip_address: The IP address to look up in the cache. :return: The NIC associated with the IP address, or None if not found. """ - arp_entry = self.arp_cache.get(ip_address) + arp_entry = self.arp.get(ip_address) if arp_entry: return self.nics[arp_entry.nic_uuid] - def _clear_arp_cache(self): - """Clear the entire ARP cache.""" - self.arp_cache.clear() + def clear_arp_cache(self): + """Clear the entire ARP cache, removing all stored entries.""" + self.arp.clear() - def _send_arp_request(self, target_ip_address: Union[IPv4Address, str]): - """Perform a standard ARP request for a given target IP address.""" + def send_arp_request(self, target_ip_address: Union[IPv4Address, str]): + """ + Perform a standard ARP request for a given target IP address. + + Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP + address. + + :param target_ip_address: The target IP address to send an ARP request for. + """ for nic in self.nics.values(): if nic.enabled: self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}") @@ -547,12 +613,13 @@ class Node(SimComponent): def process_arp_packet(self, from_nic: NIC, arp_packet: ARPPacket): """ - Process an ARP packet. + Process a received ARP packet, handling both ARP requests and responses. - # TODO: This will become a service that sits on the Node. + If an ARP request is received for the local IP, a response is sent back. + If an ARP response is received, the ARP cache is updated with the new entry. - :param from_nic: The NIC the arp packet was received at. - :param arp_packet:The ARP packet to process. + :param from_nic: The NIC that received the ARP packet. + :param arp_packet: The ARP packet to be processed. """ if arp_packet.request: self.sys_log.info( @@ -581,7 +648,7 @@ class Node(SimComponent): src_mac_addr=arp_packet.sender_mac_addr, dst_mac_addr=arp_packet.target_mac_addr ) frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, arp=arp_packet) - self.send_frame(frame) + from_nic.send_frame(frame) else: self.sys_log.info(f"Ignoring ARP request for {arp_packet.target_ip}") else: @@ -592,18 +659,34 @@ class Node(SimComponent): ip_address=arp_packet.sender_ip, mac_address=arp_packet.sender_mac_addr, nic=from_nic ) + +class ICMP: + """ + The ICMP (Internet Control Message Protocol) class. + + Provides functionalities for managing and handling ICMP packets, including echo requests and replies. + """ + + def __init__(self, sys_log: SysLog, arp_cache: ARPCache): + """ + Initialize the ICMP (Internet Control Message Protocol) service. + + :param sys_log: The system log to store system messages and information. + :param arp_cache: The ARP cache for resolving IP to MAC address mappings. + """ + self.sys_log: SysLog = sys_log + self.arp: ARPCache = arp_cache + def process_icmp(self, frame: Frame): """ - Process an ICMP packet. + Process an ICMP packet, including handling echo requests and replies. - # TODO: This will become a service that sits on the Node. - - :param frame: The Frame containing the icmp packet to process. + :param frame: The Frame containing the ICMP packet to process. """ if frame.icmp.icmp_type == ICMPType.ECHO_REQUEST: self.sys_log.info(f"Received echo request from {frame.ip.src_ip}") - target_mac_address = self._get_arp_cache_mac_address(frame.ip.src_ip) - src_nic = self._get_arp_cache_nic(frame.ip.src_ip) + target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip) + src_nic = self.arp.get_arp_cache_nic(frame.ip.src_ip) tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) # Network Layer @@ -617,19 +700,28 @@ class Node(SimComponent): sequence=frame.icmp.sequence + 1, ) frame = Frame(ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet) - self.sys_log.info(f"Sending echo reply to {frame.ip.src_ip}") + self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip}") src_nic.send_frame(frame) elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY: self.sys_log.info(f"Received echo reply from {frame.ip.src_ip}") - def _ping( + def ping( self, target_ip_address: IPv4Address, sequence: int = 0, identifier: Optional[int] = None ) -> Tuple[int, Union[int, None]]: - nic = self._get_arp_cache_nic(target_ip_address) + """ + Send an ICMP echo request (ping) to a target IP address and manage the sequence and identifier. + + :param target_ip_address: The target IP address to send the ping. + :param sequence: The sequence number of the echo request. Defaults to 0. + :param identifier: An optional identifier for the ICMP packet. If None, a default will be used. + :return: A tuple containing the next sequence number and the identifier, or (0, None) if the target IP address + was not found in the ARP cache. + """ + nic = self.arp.get_arp_cache_nic(target_ip_address) if nic: sequence += 1 - target_mac_address = self._get_arp_cache_mac_address(target_ip_address) - src_nic = self._get_arp_cache_nic(target_ip_address) + target_mac_address = self.arp.get_arp_cache_mac_address(target_ip_address) + src_nic = self.arp.get_arp_cache_nic(target_ip_address) tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) # Network Layer @@ -647,17 +739,143 @@ class Node(SimComponent): return sequence, icmp_packet.identifier else: self.sys_log.info(f"No entry in ARP cache for {target_ip_address}") - self._send_arp_request(target_ip_address) + self.arp.send_arp_request(target_ip_address) return 0, None + +class NodeOperatingState(Enum): + """Enumeration of Node Operating States.""" + + OFF = 0 + "The node is powered off." + ON = 1 + "The node is powered on." + SHUTTING_DOWN = 2 + "The node is in the process of shutting down." + BOOTING = 3 + "The node is in the process of booting up." + + +class Node(SimComponent): + """ + A basic Node class that represents a node on the network. + + This class manages the state of the node, including the NICs (Network Interface Cards), accounts, applications, + services, processes, file system, and various managers like ARP, ICMP, SessionManager, and SoftwareManager. + + :param hostname: The node hostname on the network. + :param operating_state: The node operating state, either ON or OFF. + """ + + hostname: str + "The node hostname on the network." + operating_state: NodeOperatingState = NodeOperatingState.OFF + "The hardware state of the node." + nics: Dict[str, NIC] = {} + "The NICs on the node." + + accounts: Dict = {} + "All accounts on the node." + applications: Dict = {} + "All applications on the node." + services: Dict = {} + "All services on the node." + processes: Dict = {} + "All processes on the node." + file_system: Any = None + "The nodes file system." + sys_log: SysLog + arp: ARPCache + icmp: ICMP + session_manager: SessionManager + software_manager: SoftwareManager + + revealed_to_red: bool = False + "Informs whether the node has been revealed to a red agent." + + def __init__(self, **kwargs): + """ + Initialize the Node with various components and managers. + + This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not + provided. + """ + if not kwargs.get("sys_log"): + kwargs["sys_log"] = SysLog(kwargs["hostname"]) + if not kwargs.get("arp_cache"): + kwargs["arp"] = ARPCache(sys_log=kwargs.get("sys_log")) + if not kwargs.get("icmp"): + kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) + if not kwargs.get("session_manager"): + kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) + if not kwargs.get("software_manager"): + kwargs["software_manager"] = SoftwareManager( + sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager") + ) + super().__init__(**kwargs) + self.arp.nics = self.nics + + def turn_on(self): + """Turn on the Node, enabling its NICs if it is in the OFF state.""" + if self.operating_state == NodeOperatingState.OFF: + self.operating_state = NodeOperatingState.ON + self.sys_log.info("Turned on") + for nic in self.nics.values(): + nic.enable() + + def turn_off(self): + """Turn off the Node, disabling its NICs if it is in the ON state.""" + if self.operating_state == NodeOperatingState.ON: + for nic in self.nics.values(): + nic.disable() + self.operating_state = NodeOperatingState.OFF + self.sys_log.info("Turned off") + + def connect_nic(self, nic: NIC): + """ + Connect a NIC (Network Interface Card) to the node. + + :param nic: The NIC to connect. + :raise NetworkError: If the NIC is already connected. + """ + if nic.uuid not in self.nics: + self.nics[nic.uuid] = nic + nic.connected_node = self + self.sys_log.info(f"Connected NIC {nic}") + if self.operating_state == NodeOperatingState.ON: + nic.enable() + else: + msg = f"Cannot connect NIC {nic} as it is already connected" + self.sys_log.logger.error(msg) + _LOGGER.error(msg) + raise NetworkError(msg) + + def disconnect_nic(self, nic: Union[NIC, str]): + """ + Disconnect a NIC (Network Interface Card) from the node. + + :param nic: The NIC to Disconnect, or its UUID. + :raise NetworkError: If the NIC is not connected. + """ + if isinstance(nic, str): + nic = self.nics.get(nic) + if nic or nic.uuid in self.nics: + self.nics.pop(nic.uuid) + nic.disable() + self.sys_log.info(f"Disconnected NIC {nic}") + else: + msg = f"Cannot disconnect NIC {nic} as it is not connected" + self.sys_log.logger.error(msg) + _LOGGER.error(msg) + raise NetworkError(msg) + def ping(self, target_ip_address: Union[IPv4Address, str], pings: int = 4) -> bool: """ - Ping an IP address. - - Performs a standard ICMP echo request/response four times. + Ping an IP address, performing a standard ICMP echo request/response. :param target_ip_address: The target IP address to ping. - :return: True if successful, otherwise False. + :param pings: The number of pings to attempt, default is 4. + :return: True if the ping is successful, otherwise False. """ if not isinstance(target_ip_address, IPv4Address): target_ip_address = IPv4Address(target_ip_address) @@ -665,7 +883,7 @@ class Node(SimComponent): self.sys_log.info(f"Attempting to ping {target_ip_address}") sequence, identifier = 0, None while sequence < pings: - sequence, identifier = self._ping(target_ip_address, sequence, identifier) + sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier) return True self.sys_log.info("Ping failed as the node is turned off") return False @@ -681,20 +899,101 @@ class Node(SimComponent): def receive_frame(self, frame: Frame, from_nic: NIC): """ - Receive a Frame from the connected NIC. + Receive a Frame from the connected NIC and process it. - The Frame is passed to up to the SessionManager. + Depending on the protocol, the frame is passed to the appropriate handler such as ARP or ICMP, or up to the + SessionManager if no code manager exists. :param frame: The Frame being received. + :param from_nic: The NIC that received the frame. """ if frame.ip.protocol == IPProtocol.TCP: if frame.tcp.src_port == Port.ARP: - self.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) + self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) elif frame.ip.protocol == IPProtocol.UDP: pass elif frame.ip.protocol == IPProtocol.ICMP: - self.process_icmp(frame=frame) + self.icmp.process_icmp(frame=frame) def describe_state(self) -> Dict: - """Describe the state of a Node.""" + """ + Describe the state of the Node. + + :return: A dictionary representing the state of the node. + """ pass + + +class Switch(Node): + """A class representing a Layer 2 network switch.""" + + num_ports: int = 24 + "The number of ports on the switch." + switch_ports: Dict[int, SwitchPort] = {} + "The SwitchPorts on the switch." + dst_mac_table: Dict[str, SwitchPort] = {} + "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + + def describe_state(self) -> Dict: + """TODO.""" + pass + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not self.switch_ports: + self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)} + for port_num, port in self.switch_ports.items(): + port.connected_node = self + port.port_num = port_num + + def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): + mac_table_port = self.dst_mac_table.get(mac_address) + if not mac_table_port: + self.dst_mac_table[mac_address] = switch_port + self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}") + else: + if mac_table_port != switch_port: + self.dst_mac_table.pop(mac_address) + self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}") + self._add_mac_table_entry(mac_address, switch_port) + + def forward_frame(self, frame: Frame, incoming_port: SwitchPort): + """ + Forward a frame to the appropriate port based on the destination MAC address. + + :param frame: The Frame to be forwarded. + :param incoming_port: The port number from which the frame was received. + """ + src_mac = frame.ethernet.src_mac_addr + dst_mac = frame.ethernet.dst_mac_addr + self._add_mac_table_entry(src_mac, incoming_port) + + outgoing_port = self.dst_mac_table.get(dst_mac) + if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff": + outgoing_port.send_frame(frame) + else: + # If the destination MAC is not in the table, flood to all ports except incoming + for port in self.switch_ports.values(): + if port != incoming_port: + port.send_frame(frame) + + def disconnect_link_from_port(self, link: Link, port_number: int): + """ + Disconnect a given link from the specified port number on the switch. + + :param link: The Link object to be disconnected. + :param port_number: The port number on the switch from where the link should be disconnected. + :raise NetworkError: When an invalid port number is provided or the link does not match the connection. + """ + port = self.switch_ports.get(port_number) + if port is None: + msg = f"Invalid port number {port_number} on the switch" + _LOGGER.error(msg) + raise NetworkError(msg) + + if port.connected_link != link: + msg = f"The link does not match the connection at port number {port_number}" + _LOGGER.error(msg) + raise NetworkError(msg) + + port.disconnect_link() diff --git a/src/primaite/simulator/system/services/icmp.py b/src/primaite/simulator/network/nodes/switch.py similarity index 100% rename from src/primaite/simulator/system/services/icmp.py rename to src/primaite/simulator/network/nodes/switch.py diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 31a645b5..f9c5827d 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import Any, List, Dict, Set +from typing import Any, Dict, List, Set from primaite.simulator.system.software import IOSoftware @@ -22,6 +22,7 @@ class Application(IOSoftware): Applications are user-facing programs that may perform input/output operations. """ + operating_state: ApplicationOperatingState "The current operating state of the Application." execution_control_status: str @@ -61,9 +62,9 @@ class Application(IOSoftware): """ pass - def send(self, payload: Any) -> bool: + def send(self, payload: Any, session_id: str, **kwargs) -> bool: """ - Sends a payload to the SessionManager + Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. @@ -73,7 +74,7 @@ class Application(IOSoftware): """ pass - def receive(self, payload: Any) -> bool: + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. diff --git a/src/primaite/simulator/system/arp_cache.py b/src/primaite/simulator/system/arp_cache.py deleted file mode 100644 index 1fb830ab..00000000 --- a/src/primaite/simulator/system/arp_cache.py +++ /dev/null @@ -1,30 +0,0 @@ -from ipaddress import IPv4Address - -from pydantic import BaseModel - - -class ARPCacheService(BaseModel): - def __init__(self, node): - super().__init__() - self.node = node - - def _add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC): - ... - - def _remove_arp_cache_entry(self, ip_address: IPv4Address): - ... - - def _get_arp_cache_mac_address(self, ip_address: IPv4Address) -> Optional[str]: - ... - - def _get_arp_cache_nic(self, ip_address: IPv4Address) -> Optional[NIC]: - ... - - def _clear_arp_cache(self): - ... - - def _send_arp_request(self, target_ip_address: Union[IPv4Address, str]): - ... - - def process_arp_packet(self, from_nic: NIC, arp_packet: ARPPacket): - ... \ No newline at end of file diff --git a/src/primaite/simulator/system/core/__init__.py b/src/primaite/simulator/system/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/simulator/system/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py similarity index 82% rename from src/primaite/simulator/system/packet_capture.py rename to src/primaite/simulator/system/core/packet_capture.py index c05b6db9..7741416d 100644 --- a/src/primaite/simulator/system/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -1,5 +1,8 @@ import logging from pathlib import Path +from typing import Optional + +from primaite.simulator import TEMP_SIM_OUTPUT class _JSONFilter(logging.Filter): @@ -17,7 +20,7 @@ class PacketCapture: The PCAPs are logged to: //__pcap.log """ - def __init__(self, hostname: str, ip_address: str): + def __init__(self, hostname: str, ip_address: Optional[str] = None): """ Initialize the PacketCapture process. @@ -40,7 +43,7 @@ class PacketCapture: log_format = "%(message)s" file_handler.setFormatter(logging.Formatter(log_format)) - logger_name = f"{self.hostname}_{self.ip_address}_pcap" + logger_name = f"{self.hostname}_{self.ip_address}_pcap" if self.ip_address else f"{self.hostname}_pcap" self.logger = logging.getLogger(logger_name) self.logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs self.logger.addHandler(file_handler) @@ -49,9 +52,11 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = Path(__file__).parent.parent.parent.parent.parent.parent / "simulation_output" / self.hostname + root = TEMP_SIM_OUTPUT / self.hostname root.mkdir(exist_ok=True, parents=True) - return root / f"{self.hostname}_{self.ip_address}_pcap.log" + if self.ip_address: + return root / f"{self.hostname}_{self.ip_address}_pcap.log" + return root / f"{self.hostname}_pcap.log" def capture(self, frame): # noqa - I'll have a circular import and cant use if TYPE_CHECKING ;( """ diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py new file mode 100644 index 00000000..96d6251d --- /dev/null +++ b/src/primaite/simulator/system/core/session_manager.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from ipaddress import IPv4Address +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING + +from primaite.simulator.core import SimComponent +from primaite.simulator.network.transmission.data_link_layer import Frame +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port + +if TYPE_CHECKING: + from primaite.simulator.network.hardware.base import ARPCache + from primaite.simulator.system.core.software_manager import SoftwareManager + from primaite.simulator.system.core.sys_log import SysLog + + +class Session(SimComponent): + """ + Models a network session. + + Encapsulates information related to communication between two network endpoints, including the protocol, + source and destination IPs and ports. + + :param protocol: The IP protocol used in the session. + :param src_ip: The source IP address. + :param dst_ip: The destination IP address. + :param src_port: The source port number (optional). + :param dst_port: The destination port number (optional). + :param connected: A flag indicating whether the session is connected. + """ + + protocol: IPProtocol + src_ip: IPv4Address + dst_ip: IPv4Address + src_port: Optional[Port] + dst_port: Optional[Port] + connected: bool = False + + @classmethod + def from_session_key( + cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]] + ) -> Session: + """ + Create a Session instance from a session key tuple. + + :param session_key: Tuple containing the session details. + :return: A Session instance. + """ + protocol, src_ip, dst_ip, src_port, dst_port = session_key + return Session(protocol=protocol, src_ip=src_ip, dst_ip=dst_ip, src_port=src_port, dst_port=dst_port) + + def describe_state(self) -> Dict: + """ + Describes the current state of the session as a dictionary. + + :return: A dictionary containing the current state of the session. + """ + pass + + +class SessionManager: + """ + Manages network sessions, including session creation, lookup, and communication with other components. + + :param sys_log: A reference to the system log component. + :param arp_cache: A reference to the ARP cache component. + """ + + def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"): + self.sessions_by_key: Dict[ + Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session + ] = {} + self.sessions_by_uuid: Dict[str, Session] = {} + self.sys_log: SysLog = sys_log + self.software_manager: SoftwareManager = None # Noqa + self.arp_cache: "ARPCache" = arp_cache + + def describe_state(self) -> Dict: + """ + Describes the current state of the session manager as a dictionary. + + :return: A dictionary containing the current state of the session manager. + """ + pass + + @staticmethod + def _get_session_key( + frame: Frame, from_source: bool = True + ) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]: + """ + Extracts the session key from the given frame. + + The session key is a tuple containing the following elements: + - IPProtocol: The transport protocol (e.g. TCP, UDP, ICMP). + - IPv4Address: The source IP address. + - IPv4Address: The destination IP address. + - Optional[Port]: The source port number (if applicable). + - Optional[Port]: The destination port number (if applicable). + + :param frame: The network frame from which to extract the session key. + :param from_source: A flag to indicate if the key should be extracted from the source or destination. + :return: A tuple containing the session key. + """ + protocol = frame.ip.protocol + src_ip = frame.ip.src_ip + dst_ip = frame.ip.dst_ip + if protocol == IPProtocol.TCP: + if from_source: + src_port = frame.tcp.src_port + dst_port = frame.tcp.dst_port + else: + dst_port = frame.tcp.src_port + src_port = frame.tcp.dst_port + elif protocol == IPProtocol.UDP: + if from_source: + src_port = frame.udp.src_port + dst_port = frame.udp.dst_port + else: + dst_port = frame.udp.src_port + src_port = frame.udp.dst_port + else: + src_port = None + dst_port = None + return protocol, src_ip, dst_ip, src_port, dst_port + + def receive_payload_from_software_manager(self, payload: Any, session_id: Optional[int] = None): + """ + Receive a payload from the SoftwareManager. + + If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created. + """ + # TODO: Implement session creation and + + self.send_payload_to_nic(payload, session_id) + + def send_payload_to_software_manager(self, payload: Any, session_id: int): + """ + Send a payload to the software manager. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload originates from. + """ + self.software_manager.receive_payload_from_session_manger() + + def send_payload_to_nic(self, payload: Any, session_id: int): + """ + Send a payload across the Network. + + Takes a payload and a session_id. Builds a Frame and sends it across the network via a NIC. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload originates from + """ + # TODO: Implement frame construction and sent to NIC. + pass + + def receive_payload_from_nic(self, frame: Frame): + """ + Receive a Frame from the NIC. + + Extract the session key using the _get_session_key method, and forward the payload to the appropriate + session. If the session does not exist, a new one is created. + + :param frame: The frame being received. + """ + session_key = self._get_session_key(frame) + session = self.sessions_by_key.get(session_key) + if not session: + # Create new session + session = Session.from_session_key(session_key) + self.sessions_by_key[session_key] = session + self.sessions_by_uuid[session.uuid] = session + self.software_manager.receive_payload_from_session_manger(payload=frame, session=session) + # TODO: Implement the frame deconstruction and send to SoftwareManager. diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py new file mode 100644 index 00000000..411fb6e9 --- /dev/null +++ b/src/primaite/simulator/system/core/software_manager.py @@ -0,0 +1,99 @@ +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union + +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.core.session_manager import Session +from primaite.simulator.system.core.sys_log import SysLog +from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import SoftwareType + +if TYPE_CHECKING: + from primaite.simulator.system.core.session_manager import SessionManager + from primaite.simulator.system.core.sys_log import SysLog + + +class SoftwareManager: + """A class that manages all running Services and Applications on a Node and facilitates their communication.""" + + def __init__(self, session_manager: "SessionManager", sys_log: "SysLog"): + """ + Initialize a new instance of SoftwareManager. + + :param session_manager: The session manager handling network communications. + """ + self.session_manager = session_manager + self.services: Dict[str, Service] = {} + self.applications: Dict[str, Application] = {} + self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} + self.sys_log: SysLog = sys_log + + def add_service(self, name: str, service: Service, port: Port, protocol: IPProtocol): + """ + Add a Service to the manager. + + :param name: The name of the service. + :param service: The service instance. + :param port: The port used by the service. + :param protocol: The network protocol used by the service. + """ + service.software_manager = self + self.services[name] = service + self.port_protocol_mapping[(port, protocol)] = service + + def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol): + """ + Add an Application to the manager. + + :param name: The name of the application. + :param application: The application instance. + :param port: The port used by the application. + :param protocol: The network protocol used by the application. + """ + application.software_manager = self + self.applications[name] = application + self.port_protocol_mapping[(port, protocol)] = application + + def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any): + """ + Send a payload to a specific service or application. + + :param target_software: The name of the target service or application. + :param target_software_type: The type of software (Service, Application, Process). + :param payload: The data to be sent. + :param receiver_type: The type of the target, either 'service' or 'application'. + """ + if target_software_type is SoftwareType.SERVICE: + receiver = self.services.get(target_software) + elif target_software_type is SoftwareType.APPLICATION: + receiver = self.applications.get(target_software) + else: + raise ValueError(f"Invalid receiver type {target_software_type}") + + if receiver: + receiver.receive_payload(payload) + else: + raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}") + + def send_payload_to_session_manger(self, payload: Any, session_id: Optional[int] = None): + """ + Send a payload to the SessionManager. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload is to originate from. Optional. + """ + self.session_manager.receive_payload_from_software_manager(payload, session_id) + + def receive_payload_from_session_manger(self, payload: Any, session: Session): + """ + Receive a payload from the SessionManager and forward it to the corresponding service or application. + + :param payload: The payload being received. + :param session: The transport session the payload originates from. + """ + # receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) + # if receiver: + # receiver.receive_payload(None, payload) + # else: + # raise ValueError(f"No service or application found for port {port} and protocol {protocol}") + pass diff --git a/src/primaite/simulator/system/sys_log.py b/src/primaite/simulator/system/core/sys_log.py similarity index 90% rename from src/primaite/simulator/system/sys_log.py rename to src/primaite/simulator/system/core/sys_log.py index bb2fd7ec..4b858c2e 100644 --- a/src/primaite/simulator/system/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -1,6 +1,8 @@ import logging from pathlib import Path +from primaite.simulator import TEMP_SIM_OUTPUT + class _NotJSONFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: @@ -31,8 +33,10 @@ class SysLog: def _setup_logger(self): """ - Configures the logger for this SysLog instance. The logger is set to the DEBUG level, - and is equipped with a handler that writes to a file and filters out JSON-like messages. + Configures the logger for this SysLog instance. + + The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out + JSON-like messages. """ log_path = self._get_log_path() @@ -54,7 +58,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = Path(__file__).parent.parent.parent.parent.parent.parent / "simulation_output" / self.hostname + root = TEMP_SIM_OUTPUT / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index 68f3102f..bbd94345 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import List, Dict, Any +from typing import Dict from primaite.simulator.system.software import Software @@ -20,6 +20,7 @@ class Process(Software): Processes are executed by a Node and do not have the ability to performing input/output operations. """ + operating_state: ProcessOperatingState "The current operating state of the Process." diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index a66249ad..c820cef3 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -28,6 +28,7 @@ class Service(IOSoftware): Services are programs that run in the background and may perform input/output operations. """ + operating_state: ServiceOperatingState "The current operating state of the Service." @@ -61,9 +62,9 @@ class Service(IOSoftware): """ pass - def send(self, payload: Any) -> bool: + def send(self, payload: Any, session_id: str, **kwargs) -> bool: """ - Sends a payload to the SessionManager + Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. @@ -73,7 +74,7 @@ class Service(IOSoftware): """ pass - def receive(self, payload: Any) -> bool: + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. @@ -84,4 +85,3 @@ class Service(IOSoftware): :return: True if successful, False otherwise. """ pass - diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index e5991429..854e7e2b 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -6,6 +6,24 @@ from primaite.simulator.core import SimComponent from primaite.simulator.network.transmission.transport_layer import Port +class SoftwareType(Enum): + """ + An enumeration representing the different types of software within a simulated environment. + + Members: + - APPLICATION: User-facing programs that may perform input/output operations. + - SERVICE: Represents programs that run in the background and may perform input/output operations. + - PROCESS: Software executed by a Node that does not have the ability to performing input/output operations. + """ + + APPLICATION = 1 + "User-facing software that may perform input/output operations." + SERVICE = 2 + "Software that runs in the background and may perform input/output operations." + PROCESS = 3 + "Software executed by a Node that does not have the ability to performing input/output operations." + + class SoftwareHealthState(Enum): """Enumeration of the Software Health States.""" @@ -41,6 +59,7 @@ class Software(SimComponent): This class is intended to be subclassed by specific types of software entities. It outlines the fundamental attributes and behaviors expected of any software in the simulation. """ + name: str "The name of the software." health_state_actual: SoftwareHealthState @@ -100,6 +119,7 @@ class IOSoftware(Software): OSI Model), process them according to their internals, and send a response payload back to the SessionManager if required. """ + installing_count: int = 0 "The number of times the software has been installed. Default is 0." max_sessions: int = 1 @@ -111,26 +131,44 @@ class IOSoftware(Software): ports: Set[Port] "The set of ports to which the software is connected." - def send(self, payload: Any) -> bool: + @abstractmethod + def describe_state(self) -> Dict: """ - Sends a payload to the SessionManager + Describes the current state of the software. + + The specifics of the software's state, including its health, criticality, + and any other pertinent information, should be implemented in subclasses. + + :return: A dictionary containing key-value pairs representing the current state of the software. + :rtype: Dict + """ + pass + + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. :param payload: The payload to send. - :return: True if successful, False otherwise. + :param session_id: The identifier of the session that the payload is associated with. + :param kwargs: Additional keyword arguments specific to the implementation. + :return: True if the payload was successfully sent, False otherwise. """ pass - def receive(self, payload: Any) -> bool: + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receives a payload from the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. + :param payload: The payload to receive. - :return: True if successful, False otherwise. + :param session_id: The identifier of the session that the payload is associated with. + :param kwargs: Additional keyword arguments specific to the implementation. + :return: True if the payload was successfully received and processed, False otherwise. """ pass diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index 9681e72d..27545edc 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -1,4 +1,4 @@ -from primaite.simulator.network.hardware.base import Link, NIC, Node +from primaite.simulator.network.hardware.base import Link, NIC, Node, Switch def test_node_to_node_ping(): @@ -35,10 +35,40 @@ def test_multi_nic(): node_c.connect_nic(nic_c) node_c.turn_on() - link_a_b1 = Link(endpoint_a=nic_a, endpoint_b=nic_b1) + Link(endpoint_a=nic_a, endpoint_b=nic_b1) - link_b2_c = Link(endpoint_a=nic_b2, endpoint_b=nic_c) + Link(endpoint_a=nic_b2, endpoint_b=nic_c) node_a.ping("192.168.0.11") - node_c.ping("10.0.0.12") \ No newline at end of file + node_c.ping("10.0.0.12") + + +def test_switched_network(): + node_a = Node(hostname="node_a") + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_a.connect_nic(nic_a) + node_a.turn_on() + + node_b = Node(hostname="node_b") + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_b.connect_nic(nic_b) + node_b.turn_on() + + node_c = Node(hostname="node_c") + nic_c = NIC(ip_address="192.168.0.12", subnet_mask="255.255.255.0", gateway="192.168.0.1") + node_c.connect_nic(nic_c) + node_c.turn_on() + + switch_1 = Switch(hostname="switch_1") + switch_1.turn_on() + + switch_2 = Switch(hostname="switch_2") + switch_2.turn_on() + + Link(endpoint_a=nic_a, endpoint_b=switch_1.switch_ports[1]) + Link(endpoint_a=nic_b, endpoint_b=switch_1.switch_ports[2]) + Link(endpoint_a=switch_1.switch_ports[24], endpoint_b=switch_2.switch_ports[24]) + Link(endpoint_a=nic_c, endpoint_b=switch_2.switch_ports[1]) + + node_a.ping("192.168.0.12")