diff --git a/.gitignore b/.gitignore index ff86b65f..66d528a8 100644 --- a/.gitignore +++ b/.gitignore @@ -144,9 +144,11 @@ cython_debug/ # IDE .idea/ docs/source/primaite-dependencies.rst +.vscode/ # outputs src/primaite/outputs/ +simulation_output/ # benchmark session outputs benchmark/output diff --git a/CHANGELOG.md b/CHANGELOG.md index 14a53d73..d9700f83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ SessionManager. 1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP) - Red Agent Services: - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database) +- DNS Services: DNS Client and DNS Server ## [2.0.0] - 2023-07-26 diff --git a/docs/source/simulation_components/system/dns_client_server.rst b/docs/source/simulation_components/system/dns_client_server.rst new file mode 100644 index 00000000..f57f903b --- /dev/null +++ b/docs/source/simulation_components/system/dns_client_server.rst @@ -0,0 +1,56 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +DNS Client Server +================= + +DNS Server +---------- +Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class. + +Key capabilities +^^^^^^^^^^^^^^^^ + +- Simulates DNS requests and DNSPacket transfer across a network +- Registers domain names and the IP Address linked to the domain name +- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +^^^^^ +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) + +Implementation +^^^^^^^^^^^^^^ + +- DNS request and responses use a ``DNSPacket`` object +- Extends Service class for integration with ``SoftwareManager``. + +DNS Client +---------- + +The DNSClient provides a client interface for connecting to the ``DNSServer``. + +Key features +^^^^^^^^^^^^ + +- Connects to the ``DNSServer`` via the ``SoftwareManager``. +- Executes DNS lookup requests and keeps a cache of known domain name IP addresses. +- Handles connection to DNSServer and querying for domain name IP addresses. + +Usage +^^^^^ + +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) +- Execute domain name checks with ``check_domain_exists``. +- ``DNSClient`` will automatically add the IP Address of the domain into its cache + +Implementation +^^^^^^^^^^^^^^ + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Provides easy interface for Nodes to find IP addresses via domain names. +- Extends base Service class. diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index d0355d3a..275fdaf9 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -17,3 +17,4 @@ Contents database_client_server data_manipulation_bot + dns_client_server diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 32db95c6..ee19abb3 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -192,6 +192,9 @@ class SimComponent(BaseModel): :param action: List describing the action to apply to this object. :type action: List[str] + + :param: context: Dict containing context for actions + :type context: Dict """ if self.action_manager is None: return diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index bceb385c..dd2130d2 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -5,7 +5,7 @@ import secrets from enum import Enum from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, Literal, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable @@ -89,8 +89,6 @@ class NIC(SimComponent): "The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B" wake_on_lan: bool = False "Indicates if the NIC supports Wake-on-LAN functionality." - dns_servers: List[IPv4Address] = [] - "List of IP addresses of DNS servers used for name resolution." connected_node: Optional[Node] = None "The Node to which the NIC is connected." connected_link: Optional[Link] = None @@ -406,7 +404,8 @@ class SwitchPort(SimComponent): if self.enabled: frame.decrement_ttl() self.pcap.capture(frame) - self.connected_node.forward_frame(frame=frame, incoming_port=self) + connected_node: Node = self.connected_node + connected_node.forward_frame(frame=frame, incoming_port=self) return True return False @@ -881,6 +880,8 @@ class Node(SimComponent): "The NICs on the node." ethernet_port: Dict[int, NIC] = {} "The NICs on the node by port id." + dns_server: Optional[IPv4Address] = None + "List of IP addresses of DNS servers used for name resolution." accounts: Dict[str, Account] = {} "All accounts on the node." @@ -930,6 +931,7 @@ class Node(SimComponent): sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager"), file_system=kwargs.get("file_system"), + dns_server=kwargs.get("dns_server"), ) super().__init__(**kwargs) self.arp.nics = self.nics diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index ce1ef338..78d2e68f 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -10,6 +10,8 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database_service import DatabaseService +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot @@ -126,9 +128,16 @@ def arcd_uc2_network() -> Network: # Client 1 client_1 = Computer( - hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1" + hostname="client_1", + ip_address="192.168.10.21", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + dns_server=IPv4Address("192.168.1.10"), ) client_1.power_on() + client_1.software_manager.install(DNSClient) + client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa + client_1_dns_client_service.start() network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] @@ -136,9 +145,16 @@ def arcd_uc2_network() -> Network: # Client 2 client_2 = Computer( - hostname="client_2", ip_address="192.168.10.22", subnet_mask="255.255.255.0", default_gateway="192.168.10.1" + hostname="client_2", + ip_address="192.168.10.22", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + dns_server=IPv4Address("192.168.1.10"), ) client_2.power_on() + client_2.software_manager.install(DNSClient) + client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa + client_2_dns_client_service.start() network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) # Domain Controller @@ -149,6 +165,8 @@ def arcd_uc2_network() -> Network: default_gateway="192.168.1.1", ) domain_controller.power_on() + domain_controller.software_manager.install(DNSServer) + network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1]) # Database Server @@ -157,6 +175,7 @@ def arcd_uc2_network() -> Network: ip_address="192.168.1.14", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) database_server.power_on() network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3]) @@ -196,19 +215,33 @@ def arcd_uc2_network() -> Network: # Web Server web_server = Server( - hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" + hostname="web_server", + ip_address="192.168.1.12", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) web_server.power_on() web_server.software_manager.install(DatabaseClient) + database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] database_client.configure(server_ip_address=IPv4Address("192.168.1.14")) network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) database_client.run() database_client.connect() + # register the web_server to a domain + dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa + dns_server_service.start() + dns_server_service.dns_register("arcd.com", web_server.ip_address) + # Backup Server backup_server = Server( - hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" + hostname="backup_server", + ip_address="192.168.1.16", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) backup_server.power_on() network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4]) @@ -219,6 +252,7 @@ def arcd_uc2_network() -> Network: ip_address="192.168.1.110", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) security_suite.power_on() network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7]) @@ -229,6 +263,12 @@ def arcd_uc2_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) + # Allow PostgreSQL requests + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + ) + + # Allow DNS requests + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) return network diff --git a/src/primaite/simulator/network/protocols/dns.py b/src/primaite/simulator/network/protocols/dns.py new file mode 100644 index 00000000..41bf5e0c --- /dev/null +++ b/src/primaite/simulator/network/protocols/dns.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from ipaddress import IPv4Address +from typing import Optional + +from pydantic import BaseModel + + +class DNSRequest(BaseModel): + """Represents a DNS Request packet of a network frame. + + :param domain_name_request: Domain Name Request for IP address. + """ + + domain_name_request: str + "Domain Name Request for IP address." + + +class DNSReply(BaseModel): + """Represents a DNS Reply packet of a network frame. + + :param domain_name_ip_address: IP Address of the Domain Name requested. + """ + + domain_name_ip_address: Optional[IPv4Address] = None + "IP Address of the Domain Name requested." + + +class DNSPacket(BaseModel): + """ + Represents the DNS layer of a network frame. + + :param dns_request: DNS Request packet sent by DNS Client. + :param dns_reply: DNS Reply packet generated by DNS Server. + + :Example: + + >>> dns_request = DNSPacket( + ... domain_name_request=DNSRequest(domain_name_request="www.google.co.uk"), + ... dns_reply=None + ... ) + >>> dns_response = DNSPacket( + ... dns_request=DNSRequest(domain_name_request="www.google.co.uk"), + ... dns_reply=DNSReply(domain_name_ip_address=IPv4Address("142.250.179.227")) + ... ) + """ + + dns_request: DNSRequest + "DNS Request packet sent by DNS Client." + dns_reply: Optional[DNSReply] = None + "DNS Reply packet generated by DNS Server." + + def generate_reply(self, domain_ip_address: IPv4Address) -> DNSPacket: + """Generate a new DNSPacket to be sent as a response with a DNS Reply packet which contains the IP address. + + :param domain_ip_address: The IP address that was being sought after from the original target domain name. + :return: A new instance of DNSPacket. + """ + self.dns_reply = DNSReply(domain_name_ip_address=domain_ip_address) + + return self diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py new file mode 100644 index 00000000..78d196b7 --- /dev/null +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -0,0 +1,54 @@ +from ipaddress import IPv4Address +from typing import Any, Dict, Optional + +from primaite.simulator.system.applications.application import Application + + +class WebBrowser(Application): + """ + Represents a web browser in the simulation environment. + + The application requests and loads web pages using its domain name and requesting IP addresses using DNS. + """ + + domain_name: str + "The domain name of the webpage." + domain_name_ip_address: Optional[IPv4Address] + "The IP address of the domain name for the webpage." + history: Dict[str] + "A dict that stores all of the previous domain names." + + def reset_component_for_episode(self, episode: int): + """ + Resets the Application component for a new episode. + + This method ensures the Application is ready for a new episode, including resetting any + stateful properties or statistics, and clearing any message queues. + """ + self.domain_name = "" + self.domain_name_ip_address = None + self.history = {} + + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Sends a payload to the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to send. + :return: True if successful, False otherwise. + """ + pass + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Receives a payload from the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to receive. + :return: True if successful, False otherwise. + """ + pass diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 71b7dcec..95ece9f9 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -211,7 +211,7 @@ class SessionManager: session = Session.from_session_key(session_key) self.sessions_by_key[session_key] = session self.sessions_by_uuid[session.uuid] = session - self.software_manager.receive_payload_from_session_manger( + self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid ) diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 6860ebc2..99445bf8 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -23,7 +23,13 @@ IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) class SoftwareManager: """A class that manages all running Services and Applications on a Node and facilitates their communication.""" - def __init__(self, session_manager: "SessionManager", sys_log: SysLog, file_system: FileSystem): + def __init__( + self, + session_manager: "SessionManager", + sys_log: SysLog, + file_system: FileSystem, + dns_server: Optional[IPv4Address], + ): """ Initialize a new instance of SoftwareManager. @@ -35,6 +41,7 @@ class SoftwareManager: self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log self.file_system: FileSystem = file_system + self.dns_server: Optional[IPv4Address] = dns_server def get_open_ports(self) -> List[Port]: """ @@ -58,7 +65,9 @@ class SoftwareManager: if software_class in self._software_class_to_name_map: self.sys_log.info(f"Cannot install {software_class} as it is already installed") return - software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system) + software = software_class( + software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server + ) if isinstance(software, Application): software.install() software.software_manager = self @@ -114,7 +123,7 @@ class SoftwareManager: payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id ) - def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): + def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): """ Receive a payload from the SessionManager and forward it to the corresponding service or application. diff --git a/src/primaite/simulator/system/services/dns_client.py b/src/primaite/simulator/system/services/dns_client.py new file mode 100644 index 00000000..cf5278af --- /dev/null +++ b/src/primaite/simulator/system/services/dns_client.py @@ -0,0 +1,154 @@ +from ipaddress import IPv4Address +from typing import Dict, Optional + +from primaite import getLogger +from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.simulator.system.services.service import Service + +_LOGGER = getLogger(__name__) + + +class DNSClient(Service): + """Represents a DNS Client as a Service.""" + + dns_cache: Dict[str, IPv4Address] = {} + "A dict of known mappings between domain/URLs names and IPv4 addresses." + dns_server: Optional[IPv4Address] = None + "The DNS Server the client sends requests to." + + def __init__(self, **kwargs): + kwargs["name"] = "DNSClient" + kwargs["port"] = Port.DNS + # DNS uses UDP by default + # it switches to TCP when the bytes exceed 512 (or 4096) bytes + # TCP for now + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + """ + Describes the current state of the software. + + The specifics of the software's state, including its health, criticality, + and any other pertinent information, should be implemented in subclasses. + + :return: A dictionary containing key-value pairs representing the current state of the software. + :rtype: Dict + """ + state = super().describe_state() + return state + + def reset_component_for_episode(self, episode: int): + """ + Resets the Service component for a new episode. + + This method ensures the Service is ready for a new episode, including resetting any + stateful properties or statistics, and clearing any message queues. + """ + pass + + def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address): + """ + Adds a domain name to the DNS Client cache. + + :param: domain_name: The domain name to save to cache + :param: ip_address: The IP Address to attach the domain name to + """ + self.dns_cache[domain_name] = ip_address + + def check_domain_exists( + self, + target_domain: str, + session_id: Optional[str] = None, + is_reattempt: bool = False, + ) -> bool: + """Function to check if domain name exists. + + :param: target_domain: The domain requested for an IP address. + :param: session_id: The Session ID the payload is to originate from. Optional. + :param: is_reattempt: Checks if the request has been reattempted. Default is False. + """ + # check if the target domain is in the client's DNS cache + payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain)) + + # check if the domain is already in the DNS cache + if target_domain in self.dns_cache: + self.sys_log.info( + f"DNS Client: Domain lookup for {target_domain} successful, resolves to {self.dns_cache[target_domain]}" + ) + return True + else: + # return False if already reattempted + if is_reattempt: + self.sys_log.info(f"DNS Client: Domain lookup for {target_domain} failed") + return False + else: + # send a request to check if domain name exists in the DNS Server + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS + ) + + # recursively re-call the function passing is_reattempt=True + return self.check_domain_exists( + target_domain=target_domain, + session_id=session_id, + is_reattempt=True, + ) + + def send( + self, + payload: DNSPacket, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Sends a payload to the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. + :param session_id: The Session ID the payload is to originate from. Optional. + + :return: True if successful, False otherwise. + """ + # create DNS request packet + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) + return True + + def receive( + self, + payload: DNSPacket, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Receives a payload from the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload is to originate from. Optional. + :return: True if successful, False otherwise. + """ + # The payload should be a DNS packet + if not isinstance(payload, DNSPacket): + _LOGGER.debug(f"{payload} is not a DNSPacket") + return False + # cast payload into a DNS packet + payload: DNSPacket = payload + if payload.dns_reply is not None: + # add the IP address to the client cache + if payload.dns_reply.domain_name_ip_address: + self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address + return True + + return False diff --git a/src/primaite/simulator/system/services/dns_server.py b/src/primaite/simulator/system/services/dns_server.py new file mode 100644 index 00000000..c6a9afd3 --- /dev/null +++ b/src/primaite/simulator/system/services/dns_server.py @@ -0,0 +1,122 @@ +from ipaddress import IPv4Address +from typing import Any, Dict, Optional + +from prettytable import MARKDOWN, PrettyTable + +from primaite import getLogger +from primaite.simulator.network.protocols.dns import DNSPacket +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import Service + +_LOGGER = getLogger(__name__) + + +class DNSServer(Service): + """Represents a DNS Server as a Service.""" + + dns_table: Dict[str, IPv4Address] = {} + "A dict of mappings between domain names and IPv4 addresses." + + def __init__(self, **kwargs): + kwargs["name"] = "DNSServer" + kwargs["port"] = Port.DNS + # DNS uses UDP by default + # it switches to TCP when the bytes exceed 512 (or 4096) bytes + # TCP for now + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + """ + Describes the current state of the software. + + The specifics of the software's state, including its health, criticality, + and any other pertinent information, should be implemented in subclasses. + + :return: A dictionary containing key-value pairs representing the current state of the software. + :rtype: Dict + """ + state = super().describe_state() + return state + + def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]: + """ + Attempts to find the IP address for a domain name. + + :param target_domain: The single domain name requested by a DNS client. + :return ip_address: The IP address of that domain name or None. + """ + return self.dns_table.get(target_domain) + + def dns_register(self, domain_name: str, domain_ip_address: IPv4Address): + """ + Register a domain name and its IP address. + + :param: domain_name: The domain name to register + :type: domain_name: str + + :param: domain_ip_address: The IP address that the domain should route to + :type: domain_ip_address: IPv4Address + """ + self.dns_table[domain_name] = domain_ip_address + + def reset_component_for_episode(self, episode: int): + """ + Resets the Service component for a new episode. + + This method ensures the Service is ready for a new episode, including resetting any + stateful properties or statistics, and clearing any message queues. + """ + pass + + def receive( + self, + payload: Any, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Receives a payload from the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param: payload: The payload to send. + :param: session_id: The id of the session. Optional. + + :return: True if DNS request returns a valid IP, otherwise, False + """ + # The payload should be a DNS packet + if not isinstance(payload, DNSPacket): + _LOGGER.debug(f"{payload} is not a DNSPacket") + return False + # cast payload into a DNS packet + payload: DNSPacket = payload + if payload.dns_request is not None: + self.sys_log.info( + f"DNS Server: Received domain lookup request for {payload.dns_request.domain_name_request} " + f"from session {session_id}" + ) + # generate a reply with the correct DNS IP address + payload = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request)) + self.sys_log.info( + f"DNS Server: Responding to domain lookup request for {payload.dns_request.domain_name_request} " + f"with ip address: {payload.dns_reply.domain_name_ip_address}" + ) + # send reply + self.send(payload, session_id) + return payload.dns_reply.domain_name_ip_address is not None + + return False + + def show(self, markdown: bool = False): + """Prints a table of DNS Lookup table.""" + table = PrettyTable(["Domain Name", "IP Address"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} DNS Lookup table" + for dns in self.dns_table.items(): + table.add_row([dns[0], dns[1]]) + print(table) diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 30b48527..20b92027 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -72,29 +72,44 @@ class Service(IOSoftware): """ pass - def send(self, payload: Any, session_id: str, **kwargs) -> bool: + def send( + self, + payload: Any, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to send. + :param: payload: The payload to send. + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - pass + self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + def receive( + self, + payload: Any, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Receives a payload from the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to receive. + :param: payload: The payload to send. + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - pass + + pass def stop(self) -> None: """Stop the service.""" diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py new file mode 100644 index 00000000..640c268a --- /dev/null +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -0,0 +1,28 @@ +from ipaddress import IPv4Address + +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer +from primaite.simulator.system.services.service import ServiceOperatingState + + +def test_dns_client_server(uc2_network): + client_1: Computer = uc2_network.get_node_by_hostname("client_1") + domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller") + + dns_client: DNSClient = client_1.software_manager.software["DNSClient"] + dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"] + + assert dns_client.operating_state == ServiceOperatingState.RUNNING + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + dns_server.show() + + # fake domain should not be added to dns cache + assert not dns_client.check_domain_exists(target_domain="fake-domain.com") + assert dns_client.dns_cache.get("fake-domain.com", None) is None + + # arcd.com is registered in dns server and should be saved to cache + assert dns_client.check_domain_exists(target_domain="arcd.com") + assert dns_client.dns_cache.get("arcd.com", None) is not None diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index 9500c4a3..aff5496a 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -45,7 +45,6 @@ def test_seeded_learning(temp_primaite_session): ), "Expected output is based upon a agent that was trained with seed 67890" session.learn() actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict() - print(actual_mean_reward_per_episode, "THISt") assert actual_mean_reward_per_episode == expected_mean_reward_per_episode diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py new file mode 100644 index 00000000..b4f20539 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py @@ -0,0 +1,100 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer + + +@pytest.fixture(scope="function") +def dns_server() -> Node: + node = Node(hostname="dns_server") + node.software_manager.install(software_class=DNSServer) + node.software_manager.software["DNSServer"].start() + return node + + +@pytest.fixture(scope="function") +def dns_client() -> Node: + node = Node(hostname="dns_client") + node.software_manager.install(software_class=DNSClient) + node.software_manager.software["DNSClient"].start() + return node + + +def test_create_dns_server(dns_server): + assert dns_server is not None + dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] + assert dns_server_service.name is "DNSServer" + assert dns_server_service.port is Port.DNS + assert dns_server_service.protocol is IPProtocol.TCP + + +def test_create_dns_client(dns_client): + assert dns_client is not None + dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + assert dns_client_service.name is "DNSClient" + assert dns_client_service.port is Port.DNS + assert dns_client_service.protocol is IPProtocol.TCP + + +def test_dns_server_domain_name_registration(dns_server): + """Test to check if the domain name registration works.""" + dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + # return none for an unknown domain + assert dns_server_service.dns_lookup("fake-domain.com") is None + assert dns_server_service.dns_lookup("real-domain.com") is not None + + +def test_dns_client_check_domain_in_cache(dns_client): + """Test to make sure that the check_domain_in_cache returns the correct values.""" + dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + + # add a domain to the dns client cache + dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12")) + + assert dns_client_service.check_domain_exists("fake-domain.com") is False + assert dns_client_service.check_domain_exists("real-domain.com") is True + + +def test_dns_server_receive(dns_server): + """Test to make sure that the DNS Server correctly responds to a DNS Client request.""" + dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + assert ( + dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com"))) + is False + ) + + assert ( + dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com"))) + is True + ) + + dns_server_service.show() + + +def test_dns_client_receive(dns_client): + """Test to make sure the DNS Client knows how to deal with request responses.""" + dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + + dns_client_service.receive( + payload=DNSPacket( + dns_request=DNSRequest(domain_name_request="real-domain.com"), + dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")), + ) + ) + + # domain name should be saved to cache + assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")