#2248 - Tidying up the tests so that they use updated networks

This commit is contained in:
Chris McCarthy
2024-02-02 16:55:43 +00:00
parent dc5aeede33
commit cb002d644f
10 changed files with 104 additions and 93 deletions

View File

@@ -679,7 +679,7 @@ class Node(SimComponent):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["hostname"])
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
if not kwargs.get("file_system"):

View File

@@ -1,6 +1,6 @@
import secrets
from enum import Enum
from ipaddress import IPv4Address
from ipaddress import IPv4Address, IPv4Network
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
@@ -86,10 +86,3 @@ class IPPacket(BaseModel):
"Time to Live (TTL) for the packet."
precedence: Precedence = Precedence.ROUTINE
"Precedence level for Quality of Service (default is Precedence.ROUTINE)."
def __init__(self, **kwargs):
if not isinstance(kwargs["src_ip_address"], IPv4Address):
kwargs["src_ip_address"] = IPv4Address(kwargs["src_ip_address"])
if not isinstance(kwargs["dst_ip_address"], IPv4Address):
kwargs["dst_ip_address"] = IPv4Address(kwargs["dst_ip_address"])
super().__init__(**kwargs)

View File

@@ -75,7 +75,7 @@ class SessionManager:
:param arp_cache: A reference to the ARP cache component.
"""
def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"):
def __init__(self, sys_log: SysLog):
self.sessions_by_key: Dict[
Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session
] = {}
@@ -150,8 +150,8 @@ class SessionManager:
def resolve_outbound_transmission_details(
self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, session_id: Optional[str] = None
) -> Tuple[Optional["NIC"], Optional[str], Optional[IPProtocol], bool]:
if not isinstance(dst_ip_address, IPv4Address):
) -> Tuple[Optional["NIC"], Optional[str], IPv4Address, Optional[IPProtocol], bool]:
if not isinstance(dst_ip_address, (IPv4Address, IPv4Network)):
dst_ip_address = IPv4Address(dst_ip_address)
is_broadcast = False
outbound_nic = None
@@ -192,7 +192,7 @@ class SessionManager:
if use_default_gateway:
dst_mac_address = self.software_manager.arp.get_default_gateway_mac_address()
outbound_nic = self.software_manager.arp.get_default_gateway_nic()
return outbound_nic, dst_mac_address, protocol, is_broadcast
return outbound_nic, dst_mac_address, dst_ip_address, protocol, is_broadcast
def receive_payload_from_software_manager(
self,
@@ -226,14 +226,13 @@ class SessionManager:
is_broadcast = payload.request
ip_protocol = IPProtocol.UDP
else:
outbound_nic, dst_mac_address, protocol, is_broadcast = self.resolve_outbound_transmission_details(
vals = self.resolve_outbound_transmission_details(
dst_ip_address=dst_ip_address, session_id=session_id
)
outbound_nic, dst_mac_address, dst_ip_address, protocol, is_broadcast = vals
if protocol:
ip_protocol = protocol
# Check if outbound NIC and destination MAC address are resolved
if not outbound_nic or not dst_mac_address:
return False
@@ -241,7 +240,7 @@ class SessionManager:
tcp_header = None
udp_header = None
if ip_protocol == IPProtocol.TCP:
TCPHeader(
tcp_header = TCPHeader(
src_port=dst_port,
dst_port=dst_port,
)
@@ -250,7 +249,6 @@ class SessionManager:
src_port=dst_port,
dst_port=dst_port,
)
# Construct the frame for transmission
frame = Frame(
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),

View File

@@ -162,6 +162,7 @@ class SoftwareManager:
payload=payload,
dst_ip_address=dest_ip_address,
dst_port=dest_port,
ip_protocol=ip_protocol,
session_id=session_id,
)

View File

@@ -356,6 +356,7 @@ class IOSoftware(Software):
session_id: Optional[str] = None,
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
dest_port: Optional[Port] = None,
ip_protocol: IPProtocol = IPProtocol.TCP,
**kwargs,
) -> bool:
"""
@@ -375,7 +376,11 @@ class IOSoftware(Software):
return False
return self.software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
payload=payload,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
ip_protocol=ip_protocol,
session_id=session_id
)
@abstractmethod