diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 32db95c6..ee19abb3 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -192,6 +192,9 @@ class SimComponent(BaseModel): :param action: List describing the action to apply to this object. :type action: List[str] + + :param: context: Dict containing context for actions + :type context: Dict """ if self.action_manager is None: return diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index be20a28d..f8e97442 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -32,27 +32,23 @@ class Session(SimComponent): """ protocol: IPProtocol - src_ip_address: IPv4Address - dst_ip_address: IPv4Address + with_ip_address: IPv4Address src_port: Optional[Port] dst_port: Optional[Port] connected: bool = False @classmethod - def from_session_key( - cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]] - ) -> Session: + def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session: """ Create a Session instance from a session key tuple. :param session_key: Tuple containing the session details. :return: A Session instance. """ - protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key + protocol, with_ip_address, src_port, dst_port = session_key return Session( protocol=protocol, - src_ip_address=src_ip_address, - dst_ip_address=dst_ip_address, + with_ip_address=with_ip_address, src_port=src_port, dst_port=dst_port, ) @@ -78,9 +74,7 @@ class SessionManager: """ def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"): - self.sessions_by_key: Dict[ - Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session - ] = {} + self.sessions_by_key: Dict[Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]], Session] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log self.software_manager: SoftwareManager = None # Noqa @@ -99,8 +93,8 @@ class SessionManager: @staticmethod def _get_session_key( - frame: Frame, from_source: bool = True - ) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]: + frame: Frame, inbound_frame: bool = True + ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: """ Extracts the session key from the given frame. @@ -112,36 +106,36 @@ class SessionManager: - Optional[Port]: The destination port number (if applicable). :param frame: The network frame from which to extract the session key. - :param from_source: A flag to indicate if the key should be extracted from the source or destination. :return: A tuple containing the session key. """ protocol = frame.ip.protocol - src_ip_address = frame.ip.src_ip_address - dst_ip_address = frame.ip.dst_ip_address + with_ip_address = frame.ip.src_ip_address if protocol == IPProtocol.TCP: - if from_source: + if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port else: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port + with_ip_address = frame.ip.dst_ip_address elif protocol == IPProtocol.UDP: - if from_source: + if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port else: dst_port = frame.udp.src_port src_port = frame.udp.dst_port + with_ip_address = frame.ip.dst_ip_address else: src_port = None dst_port = None - return protocol, src_ip_address, dst_ip_address, src_port, dst_port + return protocol, with_ip_address, src_port, dst_port def receive_payload_from_software_manager( self, payload: Any, - dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dst_ip_address: Optional[IPv4Address] = None, + dst_port: Optional[Port] = None, session_id: Optional[str] = None, is_reattempt: bool = False, ) -> Union[Any, None]: @@ -154,20 +148,21 @@ class SessionManager: :param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created. """ if session_id: - dest_ip_address = self.sessions_by_uuid[session_id].dst_ip_address - dest_port = self.sessions_by_uuid[session_id].dst_port + session = self.sessions_by_uuid[session_id] + dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address + dst_port = self.sessions_by_uuid[session_id].dst_port - dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dest_ip_address) + dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address) if dst_mac_address: - outbound_nic = self.arp_cache.get_arp_cache_nic(dest_ip_address) + outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address) else: if not is_reattempt: - self.arp_cache.send_arp_request(dest_ip_address) + self.arp_cache.send_arp_request(dst_ip_address) return self.receive_payload_from_software_manager( payload=payload, - dest_ip_address=dest_ip_address, - dest_port=dest_port, + dst_ip_address=dst_ip_address, + dst_port=dst_port, session_id=session_id, is_reattempt=True, ) @@ -178,17 +173,17 @@ class SessionManager: ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address), ip=IPPacket( src_ip_address=outbound_nic.ip_address, - dst_ip_address=dest_ip_address, + dst_ip_address=dst_ip_address, ), tcp=TCPHeader( - src_port=dest_port, - dst_port=dest_port, + src_port=dst_port, + dst_port=dst_port, ), payload=payload, ) if not session_id: - session_key = self._get_session_key(frame, from_source=True) + session_key = self._get_session_key(frame, inbound_frame=False) session = self.sessions_by_key.get(session_key) if not session: # Create new session @@ -198,33 +193,25 @@ class SessionManager: outbound_nic.send_frame(frame) - def send_payload_to_software_manager(self, payload: Any, session_id: int): + def receive_frame(self, frame: Frame): """ - Send a payload to the software manager. - - :param payload: The payload to be sent. - :param session_id: The Session ID the payload originates from. - """ - self.software_manager.receive_payload_from_session_manger() - - def receive_payload_from_nic(self, frame: Frame): - """ - Receive a Frame from the NIC. + Receive a Frame. Extract the session key using the _get_session_key method, and forward the payload to the appropriate session. If the session does not exist, a new one is created. :param frame: The frame being received. """ - session_key = self._get_session_key(frame) - session = self.sessions_by_key.get(session_key) + session_key = self._get_session_key(frame, inbound_frame=True) + session: Session = self.sessions_by_key.get(session_key) if not session: # Create new session session = Session.from_session_key(session_key) self.sessions_by_key[session_key] = session self.sessions_by_uuid[session.uuid] = session - self.software_manager.receive_payload_from_session_manger(payload=frame, session=session) - # TODO: Implement the frame deconstruction and send to SoftwareManager. + self.software_manager.receive_payload_from_session_manager( + payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid + ) def show(self, markdown: bool = False): """ diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 28e37963..71519ac7 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application -from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareType @@ -86,7 +85,7 @@ class SoftwareManager: payload: Any, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = None, - session_id: Optional[int] = None, + session_id: Optional[str] = None, ): """ Send a payload to the SessionManager. @@ -97,22 +96,21 @@ class SoftwareManager: :param session_id: The Session ID the payload is to originate from. Optional. """ self.session_manager.receive_payload_from_software_manager( - payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id + payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id ) - def receive_payload_from_session_manger(self, payload: Any, session: Session): + def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): """ Receive a payload from the SessionManager and forward it to the corresponding service or application. :param payload: The payload being received. :param session: The transport session the payload originates from. """ - # receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) - # if receiver: - # receiver.receive_payload(None, payload) - # else: - # raise ValueError(f"No service or application found for port {port} and protocol {protocol}") - pass + receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) + if receiver: + receiver.receive_payload(payload=payload, session_id=session_id) + else: + self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}") def show(self, markdown: bool = False): """ diff --git a/src/primaite/simulator/system/services/dns_client.py b/src/primaite/simulator/system/services/dns_client.py index 97968407..3929065d 100644 --- a/src/primaite/simulator/system/services/dns_client.py +++ b/src/primaite/simulator/system/services/dns_client.py @@ -1,19 +1,26 @@ -from abc import abstractmethod from ipaddress import IPv4Address -from typing import Any, Dict, List - -from pydantic import BaseModel +from typing import Any, Dict, Optional from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import Service -class DNSClient(BaseModel): +class DNSClient(Service): """Represents a DNS Client as a Service.""" - dns_cache: Dict[str:IPv4Address] = {} + dns_cache: Dict[str, IPv4Address] = {} "A dict of known mappings between domain/URLs names and IPv4 addresses." - @abstractmethod + def __init__(self, **kwargs): + kwargs["name"] = "DNSClient" + kwargs["port"] = Port.DNS + # DNS uses UDP by default + # it switches to TCP when the bytes exceed 512 (or 4096) bytes + kwargs["protocol"] = IPProtocol.UDP + super().__init__(**kwargs) + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -26,57 +33,109 @@ class DNSClient(BaseModel): """ return {"Operating State": self.operating_state} - def apply_action(self, action: List[str]) -> None: - """ - Applies a list of actions to the Service. - - :param action: A list of actions to apply. - """ - pass - - def reset_component_for_episode(self): + def reset_component_for_episode(self, episode: int): """ Resets the Service component for a new episode. This method ensures the Service is ready for a new episode, including resetting any stateful properties or statistics, and clearing any message queues. """ + super().reset_component_for_episode(episode=episode) self.dns_cache = {} - def check_domain_in_cache(self, target_domain: str, session_id: str): + def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address): + """ + Adds a domain name to the DNS Client cache. + + :param: domain_name: The domain name to save to cache + :param: ip_address: The IP Address to attach the domain name to + """ + self.dns_cache[domain_name] = ip_address + + def check_domain_in_cache( + self, + target_domain: str, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + is_reattempt: bool = False, + ) -> bool: """Function to check if domain name is in DNS client cache. - :param target_domain: The domain requested for an IP address. - :param session_id: The ID of the session in order to send the response to the DNS server or application. + :param: target_domain: The domain requested for an IP address. + :param: dest_ip_address: The ip address of the payload destination. + :param: dest_port: The port of the payload destination. + :param: session_id: The Session ID the payload is to originate from. Optional. + :param: is_reattempt: Checks if the request has been reattempted. Default is False. """ - if target_domain in self.dns_cache: - ip_address = self.dns_cache[target_domain] - self.send(ip_address, session_id) - else: - self.send(target_domain, session_id) + # check if the target domain is in the client's DNS cache + payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain)) - def send(self, payload: Any, session_id: str, **kwargs) -> bool: + # check if the domain is already in the DNS cache + if target_domain in self.dns_cache: + return True + else: + # return False if already reattempted + if is_reattempt: + return False + else: + # send a request to check if domain name exists in the DNS Server + self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id) + # call function again + return self.check_domain_in_cache( + target_domain=target_domain, + dest_ip_address=dest_ip_address, + dest_port=dest_port, + session_id=session_id, + is_reattempt=True, + ) + + def send( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to send. + :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. + :param session_id: The Session ID the payload is to originate from. Optional. + :return: True if successful, False otherwise. """ - DNSPacket(dns_request=DNSRequest(domain_name_request=payload), dns_reply=None) + # create DNS request packet + self.software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id + ) - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + def receive( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Receives a payload from the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to receive. (receive a DNS packet with dns request and dns reply in, send to web - browser) + :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. + :param session_id: The Session ID the payload is to originate from. Optional. :return: True if successful, False otherwise. """ + super().send() # check the DNS packet (dns request, dns reply) here and see if it actually worked pass diff --git a/src/primaite/simulator/system/services/dns_server.py b/src/primaite/simulator/system/services/dns_server.py index a2eaf9d9..3dcd89f9 100644 --- a/src/primaite/simulator/system/services/dns_server.py +++ b/src/primaite/simulator/system/services/dns_server.py @@ -1,19 +1,31 @@ -from abc import abstractmethod from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional -from pydantic import BaseModel +from prettytable import MARKDOWN, PrettyTable -from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest +from primaite import getLogger +from primaite.simulator.network.protocols.dns import DNSPacket +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import Service + +_LOGGER = getLogger(__name__) -class DNSServer(BaseModel): +class DNSServer(Service): """Represents a DNS Server as a Service.""" - dns_table: dict[str:IPv4Address] = {} + dns_table: Dict[str, IPv4Address] = {} "A dict of mappings between domain names and IPv4 addresses." - @abstractmethod + def __init__(self, **kwargs): + kwargs["name"] = "DNSServer" + kwargs["port"] = Port.DNS + # DNS uses UDP by default + # it switches to TCP when the bytes exceed 512 (or 4096) bytes + kwargs["protocol"] = IPProtocol.UDP + super().__init__(**kwargs) + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -26,15 +38,7 @@ class DNSServer(BaseModel): """ return {"Operating State": self.operating_state} - def apply_action(self, action: List[str]) -> None: - """ - Applies a list of actions to the Service. - - :param action: A list of actions to apply. (unsure) - """ - pass - - def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]: + def dns_lookup(self, target_domain: Any) -> Optional[IPv4Address]: """ Attempts to find the IP address for a domain name. @@ -42,11 +46,23 @@ class DNSServer(BaseModel): :return ip_address: The IP address of that domain name or None. """ if target_domain in self.dns_table: - self.dns_table[target_domain] + return self.dns_table[target_domain] else: return None - def reset_component_for_episode(self): + def dns_register(self, domain_name: str, domain_ip_address: IPv4Address): + """ + Register a domain name and its IP address. + + :param: domain_name: The domain name to register + :type: domain_name: str + + :param: domain_ip_address: The IP address that the domain should route to + :type: domain_ip_address: IPv4Address + """ + self.dns_table[domain_name] = domain_ip_address + + def reset_component_for_episode(self, episode: int): """ Resets the Service component for a new episode. @@ -54,36 +70,78 @@ class DNSServer(BaseModel): stateful properties or statistics, and clearing any message queues. """ self.dns_table = {} + super().reset_component_for_episode(episode=episode) - def send(self, payload: Any, session_id: str, **kwargs) -> bool: + def send( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to send. + :param: payload: The payload to send. + :param: dest_ip_address: The ip address of the machine that the payload will be sent to + :param: dest_port: The port of the machine that the payload will be sent to + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - # DNS packet will be sent from DNS Server to the DNS client - DNSPacket( - dns_request=DNSRequest(domain_name_request=self.dns_table), - dns_reply=DNSReply(domain_name_ip_address=payload), - ) + try: + self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) + except Exception as e: + _LOGGER.error(e) + return False - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + def receive( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Receives a payload from the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to receive. (take the domain name and do dns lookup) + :param: payload: The payload to send. + :param: dest_ip_address: The ip address of the machine that the payload will be sent to + :param: dest_port: The port of the machine that the payload will be sent to + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - ip_address = self.dns_lookup(payload) - if ip_address is not None: - self.send(ip_address, session_id) + # The payload should be a DNS packet + if not isinstance(payload, DNSPacket): + _LOGGER.debug(f"{payload} is not a DNSPacket") + return False + # cast payload into a DNS packet + payload: DNSPacket = payload + if payload.dns_request is not None: + # generate a reply with the correct DNS IP address + payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request)) + # send reply + self.send(payload, session_id) return True return False + + def show(self, markdown: bool = False): + """Prints a table of DNS Lookup table.""" + table = PrettyTable(["Domain Name", "IP Address"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} DNS Lookup table" + for dns in self.dns_table.items(): + table.add_row([dns[0], dns[1]]) + print(table) diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index b9340103..3011c74d 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,8 +1,10 @@ from enum import Enum +from ipaddress import IPv4Address from typing import Any, Dict, Optional from primaite import getLogger from primaite.simulator.core import Action, ActionManager +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.software import IOSoftware _LOGGER = getLogger(__name__) @@ -72,29 +74,54 @@ class Service(IOSoftware): """ pass - def send(self, payload: Any, session_id: str, **kwargs) -> bool: + def send( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to send. + :param: payload: The payload to send. + :param: dest_ip_address: The ip address of the machine that the payload will be sent to + :param: dest_port: The port of the machine that the payload will be sent to + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - pass + self.software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id + ) - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + def receive( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Receives a payload from the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to receive. + :param: payload: The payload to send. + :param: dest_ip_address: The ip address of the machine that the payload will be sent to + :param: dest_port: The port of the machine that the payload will be sent to + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - pass + + pass def stop(self) -> None: """Stop the service.""" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py new file mode 100644 index 00000000..fdb3426d --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py @@ -0,0 +1,66 @@ +import sys +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.networks import arcd_uc2_network +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer + + +@pytest.fixture(scope="function") +def dns_server() -> Node: + node = Node(hostname="dns_server") + node.software_manager.add_service(service_class=DNSServer) + node.software_manager.services["DNSServer"].start() + return node + + +@pytest.fixture(scope="function") +def dns_client() -> Node: + node = Node(hostname="dns_client") + node.software_manager.add_service(service_class=DNSClient) + node.software_manager.services["DNSClient"].start() + return node + + +def test_create_dns_server(dns_server): + assert dns_server is not None + dns_server_service: DNSServer = dns_server.software_manager.services["DNSServer"] + assert dns_server_service.name is "DNSServer" + assert dns_server_service.port is Port.DNS + assert dns_server_service.protocol is IPProtocol.UDP + + +def test_create_dns_client(dns_client): + assert dns_client is not None + dns_client_service: DNSClient = dns_client.software_manager.services["DNSClient"] + assert dns_client_service.name is "DNSClient" + assert dns_client_service.port is Port.DNS + assert dns_client_service.protocol is IPProtocol.UDP + + +def test_dns_server_domain_name_registration(dns_server): + """Test to check if the domain name registration works.""" + dns_server_service: DNSServer = dns_server.software_manager.services["DNSServer"] + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + # return none for an unknown domain + assert dns_server_service.dns_lookup("fake-domain.com") is None + assert dns_server_service.dns_lookup("real-domain.com") is not None + + +def test_dns_client_check_domain_in_cache(dns_client): + """Test to make sure that the check_domain_in_cache returns the correct values.""" + dns_client_service: DNSClient = dns_client.software_manager.services["DNSClient"] + + # add a domain to the dns client cache + dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12")) + + assert dns_client_service.check_domain_in_cache("fake-domain.com") is False + assert dns_client_service.check_domain_in_cache("real-domain.com") is True