Merged PR 177: DNS Server and Client
## Summary Carrying on where Sunil left the task. Added more functionality to the DNS. - DNS Server can keep a list of domain names that the DNS Client can request - DNS Client has a cache of domain names that are valid from the DNS Server ## Test process https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/177?_a=files&path=/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py ## Checklist - [x] This PR is linked to a **work item** - [x] I have performed **self-review** of the code - [x] I have written **tests** for any new functionality added with this PR - [X] I have updated the **documentation** if this PR changes or adds functionality - [ ] I have written/updated **design docs** if this PR implements new functionality - [X] I have update the **change log** - [x] I have run **pre-commit** checks for code style Related work items: #1752
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -144,9 +144,11 @@ cython_debug/
|
||||
# IDE
|
||||
.idea/
|
||||
docs/source/primaite-dependencies.rst
|
||||
.vscode/
|
||||
|
||||
# outputs
|
||||
src/primaite/outputs/
|
||||
simulation_output/
|
||||
|
||||
# benchmark session outputs
|
||||
benchmark/output
|
||||
|
||||
@@ -29,6 +29,7 @@ SessionManager.
|
||||
1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP)
|
||||
- Red Agent Services:
|
||||
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database)
|
||||
- DNS Services: DNS Client and DNS Server
|
||||
|
||||
## [2.0.0] - 2023-07-26
|
||||
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
DNS Client Server
|
||||
=================
|
||||
|
||||
DNS Server
|
||||
----------
|
||||
Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class.
|
||||
|
||||
Key capabilities
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
- Simulates DNS requests and DNSPacket transfer across a network
|
||||
- Registers domain names and the IP Address linked to the domain name
|
||||
- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read
|
||||
- Leverages the Service base class for install/uninstall, status tracking, etc.
|
||||
|
||||
Usage
|
||||
^^^^^
|
||||
- Install on a Node via the ``SoftwareManager`` to start the database service.
|
||||
- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future)
|
||||
|
||||
Implementation
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
- DNS request and responses use a ``DNSPacket`` object
|
||||
- Extends Service class for integration with ``SoftwareManager``.
|
||||
|
||||
DNS Client
|
||||
----------
|
||||
|
||||
The DNSClient provides a client interface for connecting to the ``DNSServer``.
|
||||
|
||||
Key features
|
||||
^^^^^^^^^^^^
|
||||
|
||||
- Connects to the ``DNSServer`` via the ``SoftwareManager``.
|
||||
- Executes DNS lookup requests and keeps a cache of known domain name IP addresses.
|
||||
- Handles connection to DNSServer and querying for domain name IP addresses.
|
||||
|
||||
Usage
|
||||
^^^^^
|
||||
|
||||
- Install on a Node via the ``SoftwareManager`` to start the database service.
|
||||
- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future)
|
||||
- Execute domain name checks with ``check_domain_exists``.
|
||||
- ``DNSClient`` will automatically add the IP Address of the domain into its cache
|
||||
|
||||
Implementation
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
- Leverages ``SoftwareManager`` for sending payloads over the network.
|
||||
- Provides easy interface for Nodes to find IP addresses via domain names.
|
||||
- Extends base Service class.
|
||||
@@ -17,3 +17,4 @@ Contents
|
||||
|
||||
database_client_server
|
||||
data_manipulation_bot
|
||||
dns_client_server
|
||||
|
||||
@@ -192,6 +192,9 @@ class SimComponent(BaseModel):
|
||||
|
||||
:param action: List describing the action to apply to this object.
|
||||
:type action: List[str]
|
||||
|
||||
:param: context: Dict containing context for actions
|
||||
:type context: Dict
|
||||
"""
|
||||
if self.action_manager is None:
|
||||
return
|
||||
|
||||
@@ -5,7 +5,7 @@ import secrets
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -89,8 +89,6 @@ class NIC(SimComponent):
|
||||
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
|
||||
wake_on_lan: bool = False
|
||||
"Indicates if the NIC supports Wake-on-LAN functionality."
|
||||
dns_servers: List[IPv4Address] = []
|
||||
"List of IP addresses of DNS servers used for name resolution."
|
||||
connected_node: Optional[Node] = None
|
||||
"The Node to which the NIC is connected."
|
||||
connected_link: Optional[Link] = None
|
||||
@@ -406,7 +404,8 @@ class SwitchPort(SimComponent):
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
self.pcap.capture(frame)
|
||||
self.connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
connected_node: Node = self.connected_node
|
||||
connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -881,6 +880,8 @@ class Node(SimComponent):
|
||||
"The NICs on the node."
|
||||
ethernet_port: Dict[int, NIC] = {}
|
||||
"The NICs on the node by port id."
|
||||
dns_server: Optional[IPv4Address] = None
|
||||
"List of IP addresses of DNS servers used for name resolution."
|
||||
|
||||
accounts: Dict[str, Account] = {}
|
||||
"All accounts on the node."
|
||||
@@ -930,6 +931,7 @@ class Node(SimComponent):
|
||||
sys_log=kwargs.get("sys_log"),
|
||||
session_manager=kwargs.get("session_manager"),
|
||||
file_system=kwargs.get("file_system"),
|
||||
dns_server=kwargs.get("dns_server"),
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self.arp.nics = self.nics
|
||||
|
||||
@@ -10,6 +10,8 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
|
||||
|
||||
@@ -126,9 +128,16 @@ def arcd_uc2_network() -> Network:
|
||||
|
||||
# Client 1
|
||||
client_1 = Computer(
|
||||
hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
|
||||
hostname="client_1",
|
||||
ip_address="192.168.10.21",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
dns_server=IPv4Address("192.168.1.10"),
|
||||
)
|
||||
client_1.power_on()
|
||||
client_1.software_manager.install(DNSClient)
|
||||
client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa
|
||||
client_1_dns_client_service.start()
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
@@ -136,9 +145,16 @@ def arcd_uc2_network() -> Network:
|
||||
|
||||
# Client 2
|
||||
client_2 = Computer(
|
||||
hostname="client_2", ip_address="192.168.10.22", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
|
||||
hostname="client_2",
|
||||
ip_address="192.168.10.22",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
dns_server=IPv4Address("192.168.1.10"),
|
||||
)
|
||||
client_2.power_on()
|
||||
client_2.software_manager.install(DNSClient)
|
||||
client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa
|
||||
client_2_dns_client_service.start()
|
||||
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
|
||||
|
||||
# Domain Controller
|
||||
@@ -149,6 +165,8 @@ def arcd_uc2_network() -> Network:
|
||||
default_gateway="192.168.1.1",
|
||||
)
|
||||
domain_controller.power_on()
|
||||
domain_controller.software_manager.install(DNSServer)
|
||||
|
||||
network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
|
||||
|
||||
# Database Server
|
||||
@@ -157,6 +175,7 @@ def arcd_uc2_network() -> Network:
|
||||
ip_address="192.168.1.14",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
dns_server=IPv4Address("192.168.1.10"),
|
||||
)
|
||||
database_server.power_on()
|
||||
network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])
|
||||
@@ -196,19 +215,33 @@ def arcd_uc2_network() -> Network:
|
||||
|
||||
# Web Server
|
||||
web_server = Server(
|
||||
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
hostname="web_server",
|
||||
ip_address="192.168.1.12",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
dns_server=IPv4Address("192.168.1.10"),
|
||||
)
|
||||
web_server.power_on()
|
||||
web_server.software_manager.install(DatabaseClient)
|
||||
|
||||
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
|
||||
database_client.run()
|
||||
database_client.connect()
|
||||
|
||||
# register the web_server to a domain
|
||||
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa
|
||||
dns_server_service.start()
|
||||
dns_server_service.dns_register("arcd.com", web_server.ip_address)
|
||||
|
||||
# Backup Server
|
||||
backup_server = Server(
|
||||
hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
hostname="backup_server",
|
||||
ip_address="192.168.1.16",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
dns_server=IPv4Address("192.168.1.10"),
|
||||
)
|
||||
backup_server.power_on()
|
||||
network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])
|
||||
@@ -219,6 +252,7 @@ def arcd_uc2_network() -> Network:
|
||||
ip_address="192.168.1.110",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
dns_server=IPv4Address("192.168.1.10"),
|
||||
)
|
||||
security_suite.power_on()
|
||||
network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])
|
||||
@@ -229,6 +263,12 @@ def arcd_uc2_network() -> Network:
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
|
||||
# Allow PostgreSQL requests
|
||||
router_1.acl.add_rule(
|
||||
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
|
||||
)
|
||||
|
||||
# Allow DNS requests
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
|
||||
|
||||
return network
|
||||
|
||||
61
src/primaite/simulator/network/protocols/dns.py
Normal file
61
src/primaite/simulator/network/protocols/dns.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DNSRequest(BaseModel):
|
||||
"""Represents a DNS Request packet of a network frame.
|
||||
|
||||
:param domain_name_request: Domain Name Request for IP address.
|
||||
"""
|
||||
|
||||
domain_name_request: str
|
||||
"Domain Name Request for IP address."
|
||||
|
||||
|
||||
class DNSReply(BaseModel):
|
||||
"""Represents a DNS Reply packet of a network frame.
|
||||
|
||||
:param domain_name_ip_address: IP Address of the Domain Name requested.
|
||||
"""
|
||||
|
||||
domain_name_ip_address: Optional[IPv4Address] = None
|
||||
"IP Address of the Domain Name requested."
|
||||
|
||||
|
||||
class DNSPacket(BaseModel):
|
||||
"""
|
||||
Represents the DNS layer of a network frame.
|
||||
|
||||
:param dns_request: DNS Request packet sent by DNS Client.
|
||||
:param dns_reply: DNS Reply packet generated by DNS Server.
|
||||
|
||||
:Example:
|
||||
|
||||
>>> dns_request = DNSPacket(
|
||||
... domain_name_request=DNSRequest(domain_name_request="www.google.co.uk"),
|
||||
... dns_reply=None
|
||||
... )
|
||||
>>> dns_response = DNSPacket(
|
||||
... dns_request=DNSRequest(domain_name_request="www.google.co.uk"),
|
||||
... dns_reply=DNSReply(domain_name_ip_address=IPv4Address("142.250.179.227"))
|
||||
... )
|
||||
"""
|
||||
|
||||
dns_request: DNSRequest
|
||||
"DNS Request packet sent by DNS Client."
|
||||
dns_reply: Optional[DNSReply] = None
|
||||
"DNS Reply packet generated by DNS Server."
|
||||
|
||||
def generate_reply(self, domain_ip_address: IPv4Address) -> DNSPacket:
|
||||
"""Generate a new DNSPacket to be sent as a response with a DNS Reply packet which contains the IP address.
|
||||
|
||||
:param domain_ip_address: The IP address that was being sought after from the original target domain name.
|
||||
:return: A new instance of DNSPacket.
|
||||
"""
|
||||
self.dns_reply = DNSReply(domain_name_ip_address=domain_ip_address)
|
||||
|
||||
return self
|
||||
54
src/primaite/simulator/system/applications/web_browser.py
Normal file
54
src/primaite/simulator/system/applications/web_browser.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
|
||||
|
||||
class WebBrowser(Application):
|
||||
"""
|
||||
Represents a web browser in the simulation environment.
|
||||
|
||||
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
|
||||
"""
|
||||
|
||||
domain_name: str
|
||||
"The domain name of the webpage."
|
||||
domain_name_ip_address: Optional[IPv4Address]
|
||||
"The IP address of the domain name for the webpage."
|
||||
history: Dict[str]
|
||||
"A dict that stores all of the previous domain names."
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Application component for a new episode.
|
||||
|
||||
This method ensures the Application is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
self.domain_name = ""
|
||||
self.domain_name_ip_address = None
|
||||
self.history = {}
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to send.
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to receive.
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
@@ -211,7 +211,7 @@ class SessionManager:
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
self.software_manager.receive_payload_from_session_manger(
|
||||
self.software_manager.receive_payload_from_session_manager(
|
||||
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,13 @@ IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
class SoftwareManager:
|
||||
"""A class that manages all running Services and Applications on a Node and facilitates their communication."""
|
||||
|
||||
def __init__(self, session_manager: "SessionManager", sys_log: SysLog, file_system: FileSystem):
|
||||
def __init__(
|
||||
self,
|
||||
session_manager: "SessionManager",
|
||||
sys_log: SysLog,
|
||||
file_system: FileSystem,
|
||||
dns_server: Optional[IPv4Address],
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of SoftwareManager.
|
||||
|
||||
@@ -35,6 +41,7 @@ class SoftwareManager:
|
||||
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
|
||||
self.sys_log: SysLog = sys_log
|
||||
self.file_system: FileSystem = file_system
|
||||
self.dns_server: Optional[IPv4Address] = dns_server
|
||||
|
||||
def get_open_ports(self) -> List[Port]:
|
||||
"""
|
||||
@@ -58,7 +65,9 @@ class SoftwareManager:
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
|
||||
software = software_class(
|
||||
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
|
||||
)
|
||||
if isinstance(software, Application):
|
||||
software.install()
|
||||
software.software_manager = self
|
||||
@@ -114,7 +123,7 @@ class SoftwareManager:
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
"""
|
||||
Receive a payload from the SessionManager and forward it to the corresponding service or application.
|
||||
|
||||
|
||||
154
src/primaite/simulator/system/services/dns_client.py
Normal file
154
src/primaite/simulator/system/services/dns_client.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSClient(Service):
|
||||
"""Represents a DNS Client as a Service."""
|
||||
|
||||
dns_cache: Dict[str, IPv4Address] = {}
|
||||
"A dict of known mappings between domain/URLs names and IPv4 addresses."
|
||||
dns_server: Optional[IPv4Address] = None
|
||||
"The DNS Server the client sends requests to."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DNSClient"
|
||||
kwargs["port"] = Port.DNS
|
||||
# DNS uses UDP by default
|
||||
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
|
||||
# TCP for now
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
The specifics of the software's state, including its health, criticality,
|
||||
and any other pertinent information, should be implemented in subclasses.
|
||||
|
||||
:return: A dictionary containing key-value pairs representing the current state of the software.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
|
||||
"""
|
||||
Adds a domain name to the DNS Client cache.
|
||||
|
||||
:param: domain_name: The domain name to save to cache
|
||||
:param: ip_address: The IP Address to attach the domain name to
|
||||
"""
|
||||
self.dns_cache[domain_name] = ip_address
|
||||
|
||||
def check_domain_exists(
|
||||
self,
|
||||
target_domain: str,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
"""Function to check if domain name exists.
|
||||
|
||||
:param: target_domain: The domain requested for an IP address.
|
||||
:param: session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
|
||||
"""
|
||||
# check if the target domain is in the client's DNS cache
|
||||
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
|
||||
|
||||
# check if the domain is already in the DNS cache
|
||||
if target_domain in self.dns_cache:
|
||||
self.sys_log.info(
|
||||
f"DNS Client: Domain lookup for {target_domain} successful, resolves to {self.dns_cache[target_domain]}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
# return False if already reattempted
|
||||
if is_reattempt:
|
||||
self.sys_log.info(f"DNS Client: Domain lookup for {target_domain} failed")
|
||||
return False
|
||||
else:
|
||||
# send a request to check if domain name exists in the DNS Server
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS
|
||||
)
|
||||
|
||||
# recursively re-call the function passing is_reattempt=True
|
||||
return self.check_domain_exists(
|
||||
target_domain=target_domain,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: DNSPacket,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
# create DNS request packet
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
return True
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: DNSPacket,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
# The payload should be a DNS packet
|
||||
if not isinstance(payload, DNSPacket):
|
||||
_LOGGER.debug(f"{payload} is not a DNSPacket")
|
||||
return False
|
||||
# cast payload into a DNS packet
|
||||
payload: DNSPacket = payload
|
||||
if payload.dns_reply is not None:
|
||||
# add the IP address to the client cache
|
||||
if payload.dns_reply.domain_name_ip_address:
|
||||
self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address
|
||||
return True
|
||||
|
||||
return False
|
||||
122
src/primaite/simulator/system/services/dns_server.py
Normal file
122
src/primaite/simulator/system/services/dns_server.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSServer(Service):
|
||||
"""Represents a DNS Server as a Service."""
|
||||
|
||||
dns_table: Dict[str, IPv4Address] = {}
|
||||
"A dict of mappings between domain names and IPv4 addresses."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DNSServer"
|
||||
kwargs["port"] = Port.DNS
|
||||
# DNS uses UDP by default
|
||||
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
|
||||
# TCP for now
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
The specifics of the software's state, including its health, criticality,
|
||||
and any other pertinent information, should be implemented in subclasses.
|
||||
|
||||
:return: A dictionary containing key-value pairs representing the current state of the software.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]:
|
||||
"""
|
||||
Attempts to find the IP address for a domain name.
|
||||
|
||||
:param target_domain: The single domain name requested by a DNS client.
|
||||
:return ip_address: The IP address of that domain name or None.
|
||||
"""
|
||||
return self.dns_table.get(target_domain)
|
||||
|
||||
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
|
||||
"""
|
||||
Register a domain name and its IP address.
|
||||
|
||||
:param: domain_name: The domain name to register
|
||||
:type: domain_name: str
|
||||
|
||||
:param: domain_ip_address: The IP address that the domain should route to
|
||||
:type: domain_ip_address: IPv4Address
|
||||
"""
|
||||
self.dns_table[domain_name] = domain_ip_address
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param: payload: The payload to send.
|
||||
:param: session_id: The id of the session. Optional.
|
||||
|
||||
:return: True if DNS request returns a valid IP, otherwise, False
|
||||
"""
|
||||
# The payload should be a DNS packet
|
||||
if not isinstance(payload, DNSPacket):
|
||||
_LOGGER.debug(f"{payload} is not a DNSPacket")
|
||||
return False
|
||||
# cast payload into a DNS packet
|
||||
payload: DNSPacket = payload
|
||||
if payload.dns_request is not None:
|
||||
self.sys_log.info(
|
||||
f"DNS Server: Received domain lookup request for {payload.dns_request.domain_name_request} "
|
||||
f"from session {session_id}"
|
||||
)
|
||||
# generate a reply with the correct DNS IP address
|
||||
payload = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
|
||||
self.sys_log.info(
|
||||
f"DNS Server: Responding to domain lookup request for {payload.dns_request.domain_name_request} "
|
||||
f"with ip address: {payload.dns_reply.domain_name_ip_address}"
|
||||
)
|
||||
# send reply
|
||||
self.send(payload, session_id)
|
||||
return payload.dns_reply.domain_name_ip_address is not None
|
||||
|
||||
return False
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""Prints a table of DNS Lookup table."""
|
||||
table = PrettyTable(["Domain Name", "IP Address"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} DNS Lookup table"
|
||||
for dns in self.dns_table.items():
|
||||
table.add_row([dns[0], dns[1]])
|
||||
print(table)
|
||||
@@ -72,29 +72,44 @@ class Service(IOSoftware):
|
||||
"""
|
||||
pass
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to send.
|
||||
:param: payload: The payload to send.
|
||||
:param: session_id: The id of the session
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def receive(
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to receive.
|
||||
:param: payload: The payload to send.
|
||||
:param: session_id: The id of the session
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
|
||||
28
tests/integration_tests/system/test_dns_client_server.py
Normal file
28
tests/integration_tests/system/test_dns_client_server.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.system.services.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
def test_dns_client_server(uc2_network):
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
|
||||
|
||||
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
|
||||
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
|
||||
|
||||
assert dns_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert dns_server.operating_state == ServiceOperatingState.RUNNING
|
||||
|
||||
dns_server.show()
|
||||
|
||||
# fake domain should not be added to dns cache
|
||||
assert not dns_client.check_domain_exists(target_domain="fake-domain.com")
|
||||
assert dns_client.dns_cache.get("fake-domain.com", None) is None
|
||||
|
||||
# arcd.com is registered in dns server and should be saved to cache
|
||||
assert dns_client.check_domain_exists(target_domain="arcd.com")
|
||||
assert dns_client.dns_cache.get("arcd.com", None) is not None
|
||||
@@ -45,7 +45,6 @@ def test_seeded_learning(temp_primaite_session):
|
||||
), "Expected output is based upon a agent that was trained with seed 67890"
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
print(actual_mean_reward_per_episode, "THISt")
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns_server import DNSServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_server() -> Node:
|
||||
node = Node(hostname="dns_server")
|
||||
node.software_manager.install(software_class=DNSServer)
|
||||
node.software_manager.software["DNSServer"].start()
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_client() -> Node:
|
||||
node = Node(hostname="dns_client")
|
||||
node.software_manager.install(software_class=DNSClient)
|
||||
node.software_manager.software["DNSClient"].start()
|
||||
return node
|
||||
|
||||
|
||||
def test_create_dns_server(dns_server):
|
||||
assert dns_server is not None
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
|
||||
assert dns_server_service.name is "DNSServer"
|
||||
assert dns_server_service.port is Port.DNS
|
||||
assert dns_server_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_create_dns_client(dns_client):
|
||||
assert dns_client is not None
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
assert dns_client_service.name is "DNSClient"
|
||||
assert dns_client_service.port is Port.DNS
|
||||
assert dns_client_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_dns_server_domain_name_registration(dns_server):
|
||||
"""Test to check if the domain name registration works."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
# return none for an unknown domain
|
||||
assert dns_server_service.dns_lookup("fake-domain.com") is None
|
||||
assert dns_server_service.dns_lookup("real-domain.com") is not None
|
||||
|
||||
|
||||
def test_dns_client_check_domain_in_cache(dns_client):
|
||||
"""Test to make sure that the check_domain_in_cache returns the correct values."""
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
|
||||
# add a domain to the dns client cache
|
||||
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
|
||||
|
||||
assert dns_client_service.check_domain_exists("fake-domain.com") is False
|
||||
assert dns_client_service.check_domain_exists("real-domain.com") is True
|
||||
|
||||
|
||||
def test_dns_server_receive(dns_server):
|
||||
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
assert (
|
||||
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
|
||||
is False
|
||||
)
|
||||
|
||||
assert (
|
||||
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
|
||||
is True
|
||||
)
|
||||
|
||||
dns_server_service.show()
|
||||
|
||||
|
||||
def test_dns_client_receive(dns_client):
|
||||
"""Test to make sure the DNS Client knows how to deal with request responses."""
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
|
||||
dns_client_service.receive(
|
||||
payload=DNSPacket(
|
||||
dns_request=DNSRequest(domain_name_request="real-domain.com"),
|
||||
dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")),
|
||||
)
|
||||
)
|
||||
|
||||
# domain name should be saved to cache
|
||||
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")
|
||||
Reference in New Issue
Block a user