diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 403d9638..69f93f51 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -679,7 +679,7 @@ class Node(SimComponent): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["hostname"]) if not kwargs.get("session_manager"): - kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) + kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log")) if not kwargs.get("root"): kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] if not kwargs.get("file_system"): diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index b581becd..38fc1977 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -1,6 +1,6 @@ import secrets from enum import Enum -from ipaddress import IPv4Address +from ipaddress import IPv4Address, IPv4Network from typing import Union from pydantic import BaseModel, field_validator, validate_call @@ -86,10 +86,3 @@ class IPPacket(BaseModel): "Time to Live (TTL) for the packet." precedence: Precedence = Precedence.ROUTINE "Precedence level for Quality of Service (default is Precedence.ROUTINE)." - - def __init__(self, **kwargs): - if not isinstance(kwargs["src_ip_address"], IPv4Address): - kwargs["src_ip_address"] = IPv4Address(kwargs["src_ip_address"]) - if not isinstance(kwargs["dst_ip_address"], IPv4Address): - kwargs["dst_ip_address"] = IPv4Address(kwargs["dst_ip_address"]) - super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index a748b7df..ce05193f 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -75,7 +75,7 @@ class SessionManager: :param arp_cache: A reference to the ARP cache component. """ - def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"): + def __init__(self, sys_log: SysLog): self.sessions_by_key: Dict[ Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session ] = {} @@ -150,8 +150,8 @@ class SessionManager: def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, session_id: Optional[str] = None - ) -> Tuple[Optional["NIC"], Optional[str], Optional[IPProtocol], bool]: - if not isinstance(dst_ip_address, IPv4Address): + ) -> Tuple[Optional["NIC"], Optional[str], IPv4Address, Optional[IPProtocol], bool]: + if not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) is_broadcast = False outbound_nic = None @@ -192,7 +192,7 @@ class SessionManager: if use_default_gateway: dst_mac_address = self.software_manager.arp.get_default_gateway_mac_address() outbound_nic = self.software_manager.arp.get_default_gateway_nic() - return outbound_nic, dst_mac_address, protocol, is_broadcast + return outbound_nic, dst_mac_address, dst_ip_address, protocol, is_broadcast def receive_payload_from_software_manager( self, @@ -226,14 +226,13 @@ class SessionManager: is_broadcast = payload.request ip_protocol = IPProtocol.UDP else: - outbound_nic, dst_mac_address, protocol, is_broadcast = self.resolve_outbound_transmission_details( + vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, session_id=session_id ) - + outbound_nic, dst_mac_address, dst_ip_address, protocol, is_broadcast = vals if protocol: ip_protocol = protocol - # Check if outbound NIC and destination MAC address are resolved if not outbound_nic or not dst_mac_address: return False @@ -241,7 +240,7 @@ class SessionManager: tcp_header = None udp_header = None if ip_protocol == IPProtocol.TCP: - TCPHeader( + tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) @@ -250,7 +249,6 @@ class SessionManager: src_port=dst_port, dst_port=dst_port, ) - # Construct the frame for transmission frame = Frame( ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address), diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index ac765018..99dc5f38 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -162,6 +162,7 @@ class SoftwareManager: payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, + ip_protocol=ip_protocol, session_id=session_id, ) diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 8930fa2f..91629f9a 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -356,6 +356,7 @@ class IOSoftware(Software): session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dest_port: Optional[Port] = None, + ip_protocol: IPProtocol = IPProtocol.TCP, **kwargs, ) -> bool: """ @@ -375,7 +376,11 @@ class IOSoftware(Software): return False return self.software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id + payload=payload, + dest_ip_address=dest_ip_address, + dest_port=dest_port, + ip_protocol=ip_protocol, + session_id=session_id ) @abstractmethod diff --git a/tests/conftest.py b/tests/conftest.py index c37226a5..8e458878 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,31 +134,72 @@ def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: @pytest.fixture(scope="function") def client_server() -> Tuple[Computer, Server]: + network = Network() + # Create Computer - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.0.1", + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", - operating_state=NodeOperatingState.ON, + start_up_duration=0, ) + computer.power_on() # Create Server server = Server( - hostname="server", ip_address="192.168.0.2", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, ) + server.power_on() # Connect Computer and Server - computer_nic = computer.nics[next(iter(computer.nics))] - server_nic = server.nics[next(iter(server.nics))] - link = Link(endpoint_a=computer_nic, endpoint_b=server_nic) + network.connect(computer.ethernet_port[1], server.ethernet_port[1]) # Should be linked - assert link.is_up + assert next(iter(network.links.values())).is_up return computer, server +@pytest.fixture(scope="function") +def client_switch_server() -> Tuple[Computer, Switch, Server]: + network = Network() + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + switch = Switch(hostname="switch", start_up_duration=0) + switch.power_on() + + network.connect(endpoint_a=computer.ethernet_port[1], endpoint_b=switch.switch_ports[1]) + network.connect(endpoint_a=server.ethernet_port[1], endpoint_b=switch.switch_ports[2]) + + assert all(link.is_up for link in network.links.values()) + + return computer, switch, server + + @pytest.fixture(scope="function") def example_network() -> Network: """ diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index b9ecb28b..5fb0917e 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -41,6 +41,7 @@ class BroadcastService(Service): payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, + ip_protocol=self.protocol ) diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index 7da9fe76..527e4b4c 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -1,34 +1,54 @@ -from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.hardware.nodes.switch import Switch + def test_node_to_node_ping(): - """Tests two Nodes are able to ping each other.""" - node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON) - node_a.connect_nic(nic_a) + """Tests two Computers are able to ping each other.""" + network = Network() - node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) - nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") - node_b.connect_nic(nic_b) + client_1 = Computer( + hostname="client_1", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + client_1.power_on() - Link(endpoint_a=nic_a, endpoint_b=nic_b) + server_1 = Server( + hostname="server_1", + ip_address="192.168.1.11", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server_1.power_on() - assert node_a.ping("192.168.0.11") + switch_1 = Switch(hostname="switch_1", start_up_duration=0) + switch_1.power_on() + + network.connect(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1]) + network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[2]) + + assert client_1.ping("192.168.1.11") def test_multi_nic(): - """Tests that Nodes with multiple NICs can ping each other and the data go across the correct links.""" - node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) + """Tests that Computers with multiple NICs can ping each other and the data go across the correct links.""" + node_a = Computer(hostname="node_a", operating_state=ComputerOperatingState.ON) nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") node_a.connect_nic(nic_a) - node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) + node_b = Computer(hostname="node_b", operating_state=ComputerOperatingState.ON) nic_b1 = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") nic_b2 = NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0") node_b.connect_nic(nic_b1) node_b.connect_nic(nic_b2) - node_c = Node(hostname="node_c", operating_state=NodeOperatingState.ON) + node_c = Computer(hostname="node_c", operating_state=ComputerOperatingState.ON) nic_c = NIC(ip_address="10.0.0.13", subnet_mask="255.0.0.0") node_c.connect_nic(nic_c) diff --git a/tests/integration_tests/network/test_link_connection.py b/tests/integration_tests/network/test_link_connection.py deleted file mode 100644 index c6aeac24..00000000 --- a/tests/integration_tests/network/test_link_connection.py +++ /dev/null @@ -1,24 +0,0 @@ -from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState - - -def test_link_up(): - """Tests Nodes, NICs, and Links can all be connected and be in an enabled/up state.""" - node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) - nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0") - node_a.connect_nic(nic_a) - - node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) - nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") - node_b.connect_nic(nic_b) - - link = Link(endpoint_a=nic_a, endpoint_b=nic_b) - - assert nic_a.enabled - assert nic_b.enabled - assert link.is_up - - -def test_ping_between_computer_and_server(client_server): - computer, server = client_server - - assert computer.ping(target_ip_address=server.nics[next(iter(server.nics))].ip_address) diff --git a/tests/integration_tests/network/test_switched_network.py b/tests/integration_tests/network/test_switched_network.py index 103dda21..8a2bd0a2 100644 --- a/tests/integration_tests/network/test_switched_network.py +++ b/tests/integration_tests/network/test_switched_network.py @@ -5,32 +5,8 @@ from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch -def test_switched_network(): +def test_switched_network(client_switch_server): """Tests a node can ping another node via the switch.""" - network = Network() + computer, switch, server = client_switch_server - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) - client_1.power_on() - - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) - server_1.power_on() - - switch_1 = Switch(hostname="switch_1", start_up_duration=0) - switch_1.power_on() - - network.connect(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1]) - network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[2]) - - assert client_1.ping("192.168.1.11") + assert computer.ping(server.ethernet_port[1].ip_address)