diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index bceb385c..04262037 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -406,7 +406,8 @@ class SwitchPort(SimComponent): if self.enabled: frame.decrement_ttl() self.pcap.capture(frame) - self.connected_node.forward_frame(frame=frame, incoming_port=self) + connected_node: Node = self.connected_node + connected_node.forward_frame(frame=frame, incoming_port=self) return True return False diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 79af75e4..0b9a2299 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -132,6 +132,8 @@ def arcd_uc2_network() -> Network: ) client_1.power_on() client_1.software_manager.install(DNSClient) + client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa + client_1_dns_client_service.start() network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] @@ -143,6 +145,8 @@ def arcd_uc2_network() -> Network: ) client_2.power_on() client_2.software_manager.install(DNSClient) + client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa + client_2_dns_client_service.start() network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) # Domain Controller @@ -215,6 +219,7 @@ def arcd_uc2_network() -> Network: # register the web_server to a domain dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa + dns_server_service.start() dns_server_service.dns_register("arcd.com", web_server.ip_address) # Backup Server diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 5c09210a..b7986622 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -5,7 +5,6 @@ from pydantic import BaseModel from primaite import getLogger from primaite.simulator.network.protocols.arp import ARPPacket -from primaite.simulator.network.protocols.dns import DNSPacket from primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader @@ -97,8 +96,6 @@ class Frame(BaseModel): "ICMP header." arp: Optional[ARPPacket] = None "ARP packet." - dns: Optional[DNSPacket] = None - "DNS packet." primaite: PrimaiteHeader "PrimAITE header." payload: Optional[Any] = None diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 06701546..95ece9f9 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -74,7 +74,9 @@ class SessionManager: """ def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"): - self.sessions_by_key: Dict[Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]], Session] = {} + self.sessions_by_key: Dict[ + Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session + ] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log self.software_manager: SoftwareManager = None # Noqa @@ -94,7 +96,7 @@ class SessionManager: @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True - ) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]: + ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: """ Extracts the session key from the given frame. diff --git a/src/primaite/simulator/system/services/dns_client.py b/src/primaite/simulator/system/services/dns_client.py index db01c05c..d6e4a05b 100644 --- a/src/primaite/simulator/system/services/dns_client.py +++ b/src/primaite/simulator/system/services/dns_client.py @@ -22,7 +22,8 @@ class DNSClient(Service): kwargs["port"] = Port.DNS # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes - kwargs["protocol"] = IPProtocol.UDP + # TCP for now + kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -84,14 +85,15 @@ class DNSClient(Service): return False else: # send a request to check if domain name exists in the DNS Server - self.software_manager.send_payload_to_session_manager( + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, ) # check if the domain has been added to cache - if self.dns_cache.get(target_domain) is None: + if self.dns_cache.get(target_domain, None) is None: # call function again return self.check_domain_exists( target_domain=target_domain, diff --git a/src/primaite/simulator/system/services/dns_server.py b/src/primaite/simulator/system/services/dns_server.py index b879d515..c36c7034 100644 --- a/src/primaite/simulator/system/services/dns_server.py +++ b/src/primaite/simulator/system/services/dns_server.py @@ -23,7 +23,8 @@ class DNSServer(Service): kwargs["port"] = Port.DNS # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes - kwargs["protocol"] = IPProtocol.UDP + # TCP for now + kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py index 77fa6017..a4514bad 100644 --- a/tests/integration_tests/system/test_dns_client_server.py +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -4,6 +4,7 @@ from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.system.services.dns_client import DNSClient from primaite.simulator.system.services.dns_server import DNSServer +from primaite.simulator.system.services.service import ServiceOperatingState def test_dns_client_server(uc2_network): @@ -13,12 +14,17 @@ def test_dns_client_server(uc2_network): dns_client: DNSClient = client_1.software_manager.software["DNSClient"] dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"] - # register a domain to web server - dns_server.dns_register("real-domain.com", IPv4Address("192.168.1.12")) + assert dns_client.operating_state == ServiceOperatingState.RUNNING + assert dns_server.operating_state == ServiceOperatingState.RUNNING dns_server.show() - dns_client.check_domain_exists(target_domain="real-domain.com", dest_ip_address=IPv4Address("192.168.1.14")) + # fake domain should not be added to dns cache + dns_client.check_domain_exists( + target_domain="fake-domain.com", dest_ip_address=IPv4Address(domain_controller.ip_address) + ) + assert dns_client.dns_cache.get("fake-domain.com", None) is None - # should register the domain in the client cache - assert dns_client.dns_cache.get("real-domain.com") is not None + # arcd.com is registered in dns server and should be saved to cache + dns_client.check_domain_exists(target_domain="arcd.com", dest_ip_address=IPv4Address(domain_controller.ip_address)) + assert dns_client.dns_cache.get("arcd.com", None) is not None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py index 943d3265..b4f20539 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py @@ -31,7 +31,7 @@ def test_create_dns_server(dns_server): dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] assert dns_server_service.name is "DNSServer" assert dns_server_service.port is Port.DNS - assert dns_server_service.protocol is IPProtocol.UDP + assert dns_server_service.protocol is IPProtocol.TCP def test_create_dns_client(dns_client): @@ -39,7 +39,7 @@ def test_create_dns_client(dns_client): dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] assert dns_client_service.name is "DNSClient" assert dns_client_service.port is Port.DNS - assert dns_client_service.protocol is IPProtocol.UDP + assert dns_client_service.protocol is IPProtocol.TCP def test_dns_server_domain_name_registration(dns_server):