Files
PrimAITE/src/primaite/simulator/system/services/dns/dns_server.py
2025-03-13 15:07:38 +00:00

138 lines
4.9 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""DNS Server."""
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
# Module-level logger obtained from primaite's getLogger, keyed by this module's name.
_LOGGER = getLogger(__name__)
class DNSServer(Service, discriminator="dns-server"):
    """
    Represents a DNS Server as a Service.

    Keeps a table of domain name to IP address mappings and answers ``DNSPacket`` lookup requests from it.
    """

    class ConfigSchema(Service.ConfigSchema):
        """ConfigSchema for DNSServer."""

        type: str = "dns-server"
        domain_mapping: dict = {}

    config: ConfigSchema = Field(default_factory=lambda: DNSServer.ConfigSchema())
    dns_table: Dict[str, IPv4Address] = {}
    "A dict of mappings between domain names and IPv4 addresses."

    def __init__(self, **kwargs):
        """
        Create the DNS server, seed its lookup table from the config and start the service.

        The ``name``, ``port`` and ``protocol`` entries of ``kwargs`` are always overwritten with the DNS
        server's fixed values before being passed to the parent constructor.

        :param kwargs: Keyword arguments forwarded to the parent ``Service`` constructor.
        """
        kwargs["name"] = "dns-server"
        kwargs["port"] = PORT_LOOKUP["DNS"]
        # DNS uses UDP by default
        # it switches to TCP when the bytes exceed 512 (or 4096) bytes
        # TCP for now
        kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
        super().__init__(**kwargs)
        # Seed the lookup table from the configured mapping (no copy is made by this assignment).
        self.dns_table = self.config.domain_mapping
        self.start()

    def describe_state(self) -> Dict:
        """
        Describes the current state of the DNS server.

        No DNS-specific entries are added; the state produced by the parent class is returned as-is.

        :return: A dictionary containing key-value pairs representing the current state of the software.
        :rtype: Dict
        """
        return super().describe_state()

    def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]:
        """
        Attempts to find the IP address for a domain name.

        :param target_domain: The single domain name requested by a DNS client.
        :return: The IP address registered for that domain name, or None if the domain is not in the table
            or ``_can_perform_action()`` reports that the service cannot act.
        """
        if not self._can_perform_action():
            return None

        return self.dns_table.get(target_domain)

    def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
        """
        Register a domain name and its IP address.

        An existing entry for the same domain name is overwritten. Nothing is registered if
        ``_can_perform_action()`` reports that the service cannot act.

        :param domain_name: The domain name to register
        :type domain_name: str
        :param domain_ip_address: The IP address that the domain should route to
        :type domain_ip_address: IPv4Address
        """
        if not self._can_perform_action():
            return

        self.dns_table[domain_name] = domain_ip_address

    def receive(
        self,
        payload: Any,
        session_id: Optional[str] = None,
        **kwargs,
    ) -> bool:
        """
        Receives a payload from the SessionManager and answers it if it is a DNS lookup request.

        A reply packet is generated from the lookup result and sent back on the same session, even when the
        domain could not be resolved (the reply then carries no IP address).

        :param payload: The payload received. Expected to be a ``DNSPacket``.
        :param session_id: The id of the session. Optional.
        :return: True if DNS request returns a valid IP, otherwise, False. False is also returned when the
            parent ``receive`` rejects the payload, the payload is not a ``DNSPacket``, or the packet carries
            no DNS request.
        """
        if not super().receive(payload=payload, session_id=session_id, **kwargs):
            return False

        # The payload should be a DNS packet
        if not isinstance(payload, DNSPacket):
            # Logged once at warning level; a second identical debug-level entry would only duplicate it.
            self.sys_log.warning(f"{payload} is not a DNSPacket")
            return False

        if payload.dns_request is None:
            return False

        self.sys_log.info(
            f"{self.name}: Received domain lookup request for {payload.dns_request.domain_name_request} "
            f"from session {session_id}"
        )
        # generate a reply with the correct DNS IP address
        reply = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
        self.sys_log.info(
            f"{self.name}: Responding to domain lookup request for {reply.dns_request.domain_name_request} "
            f"with ip address: {reply.dns_reply.domain_name_ip_address}"
        )
        # send reply
        self.send(reply, session_id)
        return reply.dns_reply.domain_name_ip_address is not None

    def show(self, markdown: bool = False):
        """
        Prints a table of DNS Lookup table.

        :param markdown: If True, the table is rendered using the markdown style.
        """
        table = PrettyTable(["Domain Name", "IP Address"])
        if markdown:
            table.set_style(MARKDOWN)
        table.align = "l"
        table.title = f"{self.sys_log.hostname} DNS Lookup table"
        for domain_name, ip_address in self.dns_table.items():
            table.add_row([domain_name, ip_address])
        print(table)