#1752: added more functionality to DNS client and server + tests

This commit is contained in:
Czar Echavez
2023-09-07 15:45:37 +01:00
parent 2cb0c238c9
commit 47dd23311b
7 changed files with 321 additions and 123 deletions

View File

@@ -192,6 +192,9 @@ class SimComponent(BaseModel):
:param action: List describing the action to apply to this object.
:type action: List[str]
:param: context: Dict containing context for actions
:type context: Dict
"""
if self.action_manager is None:
return

View File

@@ -32,27 +32,23 @@ class Session(SimComponent):
"""
protocol: IPProtocol
src_ip_address: IPv4Address
dst_ip_address: IPv4Address
with_ip_address: IPv4Address
src_port: Optional[Port]
dst_port: Optional[Port]
connected: bool = False
@classmethod
def from_session_key(
cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]
) -> Session:
def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session:
"""
Create a Session instance from a session key tuple.
:param session_key: Tuple containing the session details.
:return: A Session instance.
"""
protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key
protocol, with_ip_address, src_port, dst_port = session_key
return Session(
protocol=protocol,
src_ip_address=src_ip_address,
dst_ip_address=dst_ip_address,
with_ip_address=with_ip_address,
src_port=src_port,
dst_port=dst_port,
)
@@ -78,9 +74,7 @@ class SessionManager:
"""
def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"):
self.sessions_by_key: Dict[
Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session
] = {}
self.sessions_by_key: Dict[Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]], Session] = {}
self.sessions_by_uuid: Dict[str, Session] = {}
self.sys_log: SysLog = sys_log
self.software_manager: SoftwareManager = None # Noqa
@@ -99,8 +93,8 @@ class SessionManager:
@staticmethod
def _get_session_key(
frame: Frame, from_source: bool = True
) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]:
frame: Frame, inbound_frame: bool = True
) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]:
"""
Extracts the session key from the given frame.
@@ -112,36 +106,36 @@ class SessionManager:
- Optional[Port]: The destination port number (if applicable).
:param frame: The network frame from which to extract the session key.
:param from_source: A flag to indicate if the key should be extracted from the source or destination.
:return: A tuple containing the session key.
"""
protocol = frame.ip.protocol
src_ip_address = frame.ip.src_ip_address
dst_ip_address = frame.ip.dst_ip_address
with_ip_address = frame.ip.src_ip_address
if protocol == IPProtocol.TCP:
if from_source:
if inbound_frame:
src_port = frame.tcp.src_port
dst_port = frame.tcp.dst_port
else:
dst_port = frame.tcp.src_port
src_port = frame.tcp.dst_port
with_ip_address = frame.ip.dst_ip_address
elif protocol == IPProtocol.UDP:
if from_source:
if inbound_frame:
src_port = frame.udp.src_port
dst_port = frame.udp.dst_port
else:
dst_port = frame.udp.src_port
src_port = frame.udp.dst_port
with_ip_address = frame.ip.dst_ip_address
else:
src_port = None
dst_port = None
return protocol, src_ip_address, dst_ip_address, src_port, dst_port
return protocol, with_ip_address, src_port, dst_port
def receive_payload_from_software_manager(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
dst_ip_address: Optional[IPv4Address] = None,
dst_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Union[Any, None]:
@@ -154,20 +148,21 @@ class SessionManager:
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
"""
if session_id:
dest_ip_address = self.sessions_by_uuid[session_id].dst_ip_address
dest_port = self.sessions_by_uuid[session_id].dst_port
session = self.sessions_by_uuid[session_id]
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
dst_port = self.sessions_by_uuid[session_id].dst_port
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dest_ip_address)
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
if dst_mac_address:
outbound_nic = self.arp_cache.get_arp_cache_nic(dest_ip_address)
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
else:
if not is_reattempt:
self.arp_cache.send_arp_request(dest_ip_address)
self.arp_cache.send_arp_request(dst_ip_address)
return self.receive_payload_from_software_manager(
payload=payload,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
session_id=session_id,
is_reattempt=True,
)
@@ -178,17 +173,17 @@ class SessionManager:
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
ip=IPPacket(
src_ip_address=outbound_nic.ip_address,
dst_ip_address=dest_ip_address,
dst_ip_address=dst_ip_address,
),
tcp=TCPHeader(
src_port=dest_port,
dst_port=dest_port,
src_port=dst_port,
dst_port=dst_port,
),
payload=payload,
)
if not session_id:
session_key = self._get_session_key(frame, from_source=True)
session_key = self._get_session_key(frame, inbound_frame=False)
session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
@@ -198,33 +193,25 @@ class SessionManager:
outbound_nic.send_frame(frame)
def send_payload_to_software_manager(self, payload: Any, session_id: int):
def receive_frame(self, frame: Frame):
"""
Send a payload to the software manager.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload originates from.
"""
self.software_manager.receive_payload_from_session_manger()
def receive_payload_from_nic(self, frame: Frame):
"""
Receive a Frame from the NIC.
Receive a Frame.
Extract the session key using the _get_session_key method, and forward the payload to the appropriate
session. If the session does not exist, a new one is created.
:param frame: The frame being received.
"""
session_key = self._get_session_key(frame)
session = self.sessions_by_key.get(session_key)
session_key = self._get_session_key(frame, inbound_frame=True)
session: Session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
session = Session.from_session_key(session_key)
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
self.software_manager.receive_payload_from_session_manger(payload=frame, session=session)
# TODO: Implement the frame deconstruction and send to SoftwareManager.
self.software_manager.receive_payload_from_session_manager(
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
)
def show(self, markdown: bool = False):
"""

View File

@@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import SoftwareType
@@ -86,7 +85,7 @@ class SoftwareManager:
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[int] = None,
session_id: Optional[str] = None,
):
"""
Send a payload to the SessionManager.
@@ -97,22 +96,21 @@ class SoftwareManager:
:param session_id: The Session ID the payload is to originate from. Optional.
"""
self.session_manager.receive_payload_from_software_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
)
def receive_payload_from_session_manger(self, payload: Any, session: Session):
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
"""
Receive a payload from the SessionManager and forward it to the corresponding service or application.
:param payload: The payload being received.
:param session: The transport session the payload originates from.
"""
# receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
# if receiver:
# receiver.receive_payload(None, payload)
# else:
# raise ValueError(f"No service or application found for port {port} and protocol {protocol}")
pass
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
if receiver:
receiver.receive_payload(payload=payload, session_id=session_id)
else:
self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}")
def show(self, markdown: bool = False):
"""

View File

@@ -1,19 +1,26 @@
from abc import abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List
from pydantic import BaseModel
from typing import Any, Dict, Optional
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
class DNSClient(BaseModel):
class DNSClient(Service):
"""Represents a DNS Client as a Service."""
dns_cache: Dict[str:IPv4Address] = {}
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."
@abstractmethod
def __init__(self, **kwargs):
kwargs["name"] = "DNSClient"
kwargs["port"] = Port.DNS
# DNS uses UDP by default
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
kwargs["protocol"] = IPProtocol.UDP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -26,57 +33,109 @@ class DNSClient(BaseModel):
"""
return {"Operating State": self.operating_state}
def apply_action(self, action: List[str]) -> None:
"""
Applies a list of actions to the Service.
:param action: A list of actions to apply.
"""
pass
def reset_component_for_episode(self):
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
super().reset_component_for_episode(episode=episode)
self.dns_cache = {}
def check_domain_in_cache(self, target_domain: str, session_id: str):
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
"""
Adds a domain name to the DNS Client cache.
:param: domain_name: The domain name to save to cache
:param: ip_address: The IP Address to attach the domain name to
"""
self.dns_cache[domain_name] = ip_address
def check_domain_in_cache(
self,
target_domain: str,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> bool:
"""Function to check if domain name is in DNS client cache.
:param target_domain: The domain requested for an IP address.
:param session_id: The ID of the session in order to send the response to the DNS server or application.
:param: target_domain: The domain requested for an IP address.
:param: dest_ip_address: The ip address of the payload destination.
:param: dest_port: The port of the payload destination.
:param: session_id: The Session ID the payload is to originate from. Optional.
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
"""
if target_domain in self.dns_cache:
ip_address = self.dns_cache[target_domain]
self.send(ip_address, session_id)
else:
self.send(target_domain, session_id)
# check if the target domain is in the client's DNS cache
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
# check if the domain is already in the DNS cache
if target_domain in self.dns_cache:
return True
else:
# return False if already reattempted
if is_reattempt:
return False
else:
# send a request to check if domain name exists in the DNS Server
self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id)
# call function again
return self.check_domain_in_cache(
target_domain=target_domain,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
session_id=session_id,
is_reattempt=True,
)
def send(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
DNSPacket(dns_request=DNSRequest(domain_name_request=payload), dns_reply=None)
# create DNS request packet
self.software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
def receive(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive. (receive a DNS packet with dns request and dns reply in, send to web
browser)
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
super().send()
# check the DNS packet (dns request, dns reply) here and see if it actually worked
pass

View File

@@ -1,19 +1,31 @@
from abc import abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class DNSServer(BaseModel):
class DNSServer(Service):
"""Represents a DNS Server as a Service."""
dns_table: dict[str:IPv4Address] = {}
dns_table: Dict[str, IPv4Address] = {}
"A dict of mappings between domain names and IPv4 addresses."
@abstractmethod
def __init__(self, **kwargs):
kwargs["name"] = "DNSServer"
kwargs["port"] = Port.DNS
# DNS uses UDP by default
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
kwargs["protocol"] = IPProtocol.UDP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -26,15 +38,7 @@ class DNSServer(BaseModel):
"""
return {"Operating State": self.operating_state}
def apply_action(self, action: List[str]) -> None:
"""
Applies a list of actions to the Service.
:param action: A list of actions to apply. (unsure)
"""
pass
def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]:
def dns_lookup(self, target_domain: Any) -> Optional[IPv4Address]:
"""
Attempts to find the IP address for a domain name.
@@ -42,11 +46,23 @@ class DNSServer(BaseModel):
:return ip_address: The IP address of that domain name or None.
"""
if target_domain in self.dns_table:
self.dns_table[target_domain]
return self.dns_table[target_domain]
else:
return None
def reset_component_for_episode(self):
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
"""
Register a domain name and its IP address.
:param: domain_name: The domain name to register
:type: domain_name: str
:param: domain_ip_address: The IP address that the domain should route to
:type: domain_ip_address: IPv4Address
"""
self.dns_table[domain_name] = domain_ip_address
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
@@ -54,36 +70,78 @@ class DNSServer(BaseModel):
stateful properties or statistics, and clearing any message queues.
"""
self.dns_table = {}
super().reset_component_for_episode(episode=episode)
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
def send(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:param: payload: The payload to send.
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
:param: dest_port: The port of the machine that the payload will be sent to
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
# DNS packet will be sent from DNS Server to the DNS client
DNSPacket(
dns_request=DNSRequest(domain_name_request=self.dns_table),
dns_reply=DNSReply(domain_name_ip_address=payload),
)
try:
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
except Exception as e:
_LOGGER.error(e)
return False
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
def receive(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive. (take the domain name and do dns lookup)
:param: payload: The payload to send.
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
:param: dest_port: The port of the machine that the payload will be sent to
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
ip_address = self.dns_lookup(payload)
if ip_address is not None:
self.send(ip_address, session_id)
# The payload should be a DNS packet
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_request is not None:
# generate a reply with the correct DNS IP address
payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
# send reply
self.send(payload, session_id)
return True
return False
def show(self, markdown: bool = False):
"""Prints a table of DNS Lookup table."""
table = PrettyTable(["Domain Name", "IP Address"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} DNS Lookup table"
for dns in self.dns_table.items():
table.add_row([dns[0], dns[1]])
print(table)

View File

@@ -1,8 +1,10 @@
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.software import IOSoftware
_LOGGER = getLogger(__name__)
@@ -72,29 +74,54 @@ class Service(IOSoftware):
"""
pass
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
def send(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:param: payload: The payload to send.
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
:param: dest_port: The port of the machine that the payload will be sent to
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
pass
self.software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
def receive(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:param: payload: The payload to send.
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
:param: dest_port: The port of the machine that the payload will be sent to
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
pass
pass
def stop(self) -> None:
"""Stop the service."""

View File

@@ -0,0 +1,66 @@
import sys
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Node(hostname="dns_server")
node.software_manager.add_service(service_class=DNSServer)
node.software_manager.services["DNSServer"].start()
return node
@pytest.fixture(scope="function")
def dns_client() -> Node:
node = Node(hostname="dns_client")
node.software_manager.add_service(service_class=DNSClient)
node.software_manager.services["DNSClient"].start()
return node
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.services["DNSServer"]
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.UDP
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.services["DNSClient"]
assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.UDP
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.services["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
# return none for an unknown domain
assert dns_server_service.dns_lookup("fake-domain.com") is None
assert dns_server_service.dns_lookup("real-domain.com") is not None
def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client_service: DNSClient = dns_client.software_manager.services["DNSClient"]
# add a domain to the dns client cache
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
assert dns_client_service.check_domain_in_cache("fake-domain.com") is False
assert dns_client_service.check_domain_in_cache("real-domain.com") is True