#1752: added more functionality to DNS client and server + tests
This commit is contained in:
@@ -192,6 +192,9 @@ class SimComponent(BaseModel):
|
||||
|
||||
:param action: List describing the action to apply to this object.
|
||||
:type action: List[str]
|
||||
|
||||
:param: context: Dict containing context for actions
|
||||
:type context: Dict
|
||||
"""
|
||||
if self.action_manager is None:
|
||||
return
|
||||
|
||||
@@ -32,27 +32,23 @@ class Session(SimComponent):
|
||||
"""
|
||||
|
||||
protocol: IPProtocol
|
||||
src_ip_address: IPv4Address
|
||||
dst_ip_address: IPv4Address
|
||||
with_ip_address: IPv4Address
|
||||
src_port: Optional[Port]
|
||||
dst_port: Optional[Port]
|
||||
connected: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_session_key(
|
||||
cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]
|
||||
) -> Session:
|
||||
def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session:
|
||||
"""
|
||||
Create a Session instance from a session key tuple.
|
||||
|
||||
:param session_key: Tuple containing the session details.
|
||||
:return: A Session instance.
|
||||
"""
|
||||
protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key
|
||||
protocol, with_ip_address, src_port, dst_port = session_key
|
||||
return Session(
|
||||
protocol=protocol,
|
||||
src_ip_address=src_ip_address,
|
||||
dst_ip_address=dst_ip_address,
|
||||
with_ip_address=with_ip_address,
|
||||
src_port=src_port,
|
||||
dst_port=dst_port,
|
||||
)
|
||||
@@ -78,9 +74,7 @@ class SessionManager:
|
||||
"""
|
||||
|
||||
def __init__(self, sys_log: SysLog, arp_cache: "ARPCache"):
|
||||
self.sessions_by_key: Dict[
|
||||
Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session
|
||||
] = {}
|
||||
self.sessions_by_key: Dict[Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]], Session] = {}
|
||||
self.sessions_by_uuid: Dict[str, Session] = {}
|
||||
self.sys_log: SysLog = sys_log
|
||||
self.software_manager: SoftwareManager = None # Noqa
|
||||
@@ -99,8 +93,8 @@ class SessionManager:
|
||||
|
||||
@staticmethod
|
||||
def _get_session_key(
|
||||
frame: Frame, from_source: bool = True
|
||||
) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]:
|
||||
frame: Frame, inbound_frame: bool = True
|
||||
) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]:
|
||||
"""
|
||||
Extracts the session key from the given frame.
|
||||
|
||||
@@ -112,36 +106,36 @@ class SessionManager:
|
||||
- Optional[Port]: The destination port number (if applicable).
|
||||
|
||||
:param frame: The network frame from which to extract the session key.
|
||||
:param from_source: A flag to indicate if the key should be extracted from the source or destination.
|
||||
:return: A tuple containing the session key.
|
||||
"""
|
||||
protocol = frame.ip.protocol
|
||||
src_ip_address = frame.ip.src_ip_address
|
||||
dst_ip_address = frame.ip.dst_ip_address
|
||||
with_ip_address = frame.ip.src_ip_address
|
||||
if protocol == IPProtocol.TCP:
|
||||
if from_source:
|
||||
if inbound_frame:
|
||||
src_port = frame.tcp.src_port
|
||||
dst_port = frame.tcp.dst_port
|
||||
else:
|
||||
dst_port = frame.tcp.src_port
|
||||
src_port = frame.tcp.dst_port
|
||||
with_ip_address = frame.ip.dst_ip_address
|
||||
elif protocol == IPProtocol.UDP:
|
||||
if from_source:
|
||||
if inbound_frame:
|
||||
src_port = frame.udp.src_port
|
||||
dst_port = frame.udp.dst_port
|
||||
else:
|
||||
dst_port = frame.udp.src_port
|
||||
src_port = frame.udp.dst_port
|
||||
with_ip_address = frame.ip.dst_ip_address
|
||||
else:
|
||||
src_port = None
|
||||
dst_port = None
|
||||
return protocol, src_ip_address, dst_ip_address, src_port, dst_port
|
||||
return protocol, with_ip_address, src_port, dst_port
|
||||
|
||||
def receive_payload_from_software_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
dst_ip_address: Optional[IPv4Address] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Union[Any, None]:
|
||||
@@ -154,20 +148,21 @@ class SessionManager:
|
||||
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
|
||||
"""
|
||||
if session_id:
|
||||
dest_ip_address = self.sessions_by_uuid[session_id].dst_ip_address
|
||||
dest_port = self.sessions_by_uuid[session_id].dst_port
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
|
||||
dst_port = self.sessions_by_uuid[session_id].dst_port
|
||||
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dest_ip_address)
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dest_ip_address)
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dest_ip_address)
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=dest_port,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
@@ -178,17 +173,17 @@ class SessionManager:
|
||||
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
src_ip_address=outbound_nic.ip_address,
|
||||
dst_ip_address=dest_ip_address,
|
||||
dst_ip_address=dst_ip_address,
|
||||
),
|
||||
tcp=TCPHeader(
|
||||
src_port=dest_port,
|
||||
dst_port=dest_port,
|
||||
src_port=dst_port,
|
||||
dst_port=dst_port,
|
||||
),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
session_key = self._get_session_key(frame, from_source=True)
|
||||
session_key = self._get_session_key(frame, inbound_frame=False)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
@@ -198,33 +193,25 @@ class SessionManager:
|
||||
|
||||
outbound_nic.send_frame(frame)
|
||||
|
||||
def send_payload_to_software_manager(self, payload: Any, session_id: int):
|
||||
def receive_frame(self, frame: Frame):
|
||||
"""
|
||||
Send a payload to the software manager.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload originates from.
|
||||
"""
|
||||
self.software_manager.receive_payload_from_session_manger()
|
||||
|
||||
def receive_payload_from_nic(self, frame: Frame):
|
||||
"""
|
||||
Receive a Frame from the NIC.
|
||||
Receive a Frame.
|
||||
|
||||
Extract the session key using the _get_session_key method, and forward the payload to the appropriate
|
||||
session. If the session does not exist, a new one is created.
|
||||
|
||||
:param frame: The frame being received.
|
||||
"""
|
||||
session_key = self._get_session_key(frame)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
session_key = self._get_session_key(frame, inbound_frame=True)
|
||||
session: Session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
self.software_manager.receive_payload_from_session_manger(payload=frame, session=session)
|
||||
# TODO: Implement the frame deconstruction and send to SoftwareManager.
|
||||
self.software_manager.receive_payload_from_session_manager(
|
||||
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
|
||||
)
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.session_manager import Session
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareType
|
||||
@@ -86,7 +85,7 @@ class SoftwareManager:
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[int] = None,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Send a payload to the SessionManager.
|
||||
@@ -97,22 +96,21 @@ class SoftwareManager:
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
"""
|
||||
self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manger(self, payload: Any, session: Session):
|
||||
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
"""
|
||||
Receive a payload from the SessionManager and forward it to the corresponding service or application.
|
||||
|
||||
:param payload: The payload being received.
|
||||
:param session: The transport session the payload originates from.
|
||||
"""
|
||||
# receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
|
||||
# if receiver:
|
||||
# receiver.receive_payload(None, payload)
|
||||
# else:
|
||||
# raise ValueError(f"No service or application found for port {port} and protocol {protocol}")
|
||||
pass
|
||||
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
|
||||
if receiver:
|
||||
receiver.receive_payload(payload=payload, session_id=session_id)
|
||||
else:
|
||||
self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}")
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
from abc import abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class DNSClient(BaseModel):
|
||||
class DNSClient(Service):
|
||||
"""Represents a DNS Client as a Service."""
|
||||
|
||||
dns_cache: Dict[str:IPv4Address] = {}
|
||||
dns_cache: Dict[str, IPv4Address] = {}
|
||||
"A dict of known mappings between domain/URLs names and IPv4 addresses."
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DNSClient"
|
||||
kwargs["port"] = Port.DNS
|
||||
# DNS uses UDP by default
|
||||
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
|
||||
kwargs["protocol"] = IPProtocol.UDP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
@@ -26,57 +33,109 @@ class DNSClient(BaseModel):
|
||||
"""
|
||||
return {"Operating State": self.operating_state}
|
||||
|
||||
def apply_action(self, action: List[str]) -> None:
|
||||
"""
|
||||
Applies a list of actions to the Service.
|
||||
|
||||
:param action: A list of actions to apply.
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_component_for_episode(self):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
super().reset_component_for_episode(episode=episode)
|
||||
self.dns_cache = {}
|
||||
|
||||
def check_domain_in_cache(self, target_domain: str, session_id: str):
|
||||
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
|
||||
"""
|
||||
Adds a domain name to the DNS Client cache.
|
||||
|
||||
:param: domain_name: The domain name to save to cache
|
||||
:param: ip_address: The IP Address to attach the domain name to
|
||||
"""
|
||||
self.dns_cache[domain_name] = ip_address
|
||||
|
||||
def check_domain_in_cache(
|
||||
self,
|
||||
target_domain: str,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
"""Function to check if domain name is in DNS client cache.
|
||||
|
||||
:param target_domain: The domain requested for an IP address.
|
||||
:param session_id: The ID of the session in order to send the response to the DNS server or application.
|
||||
:param: target_domain: The domain requested for an IP address.
|
||||
:param: dest_ip_address: The ip address of the payload destination.
|
||||
:param: dest_port: The port of the payload destination.
|
||||
:param: session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
|
||||
"""
|
||||
if target_domain in self.dns_cache:
|
||||
ip_address = self.dns_cache[target_domain]
|
||||
self.send(ip_address, session_id)
|
||||
else:
|
||||
self.send(target_domain, session_id)
|
||||
# check if the target domain is in the client's DNS cache
|
||||
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
# check if the domain is already in the DNS cache
|
||||
if target_domain in self.dns_cache:
|
||||
return True
|
||||
else:
|
||||
# return False if already reattempted
|
||||
if is_reattempt:
|
||||
return False
|
||||
else:
|
||||
# send a request to check if domain name exists in the DNS Server
|
||||
self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id)
|
||||
# call function again
|
||||
return self.check_domain_in_cache(
|
||||
target_domain=target_domain,
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=dest_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to send.
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
DNSPacket(dns_request=DNSRequest(domain_name_request=payload), dns_reply=None)
|
||||
# create DNS request packet
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
|
||||
)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def receive(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to receive. (receive a DNS packet with dns request and dns reply in, send to web
|
||||
browser)
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
super().send()
|
||||
# check the DNS packet (dns request, dns reply) here and see if it actually worked
|
||||
pass
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
from abc import abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSServer(BaseModel):
|
||||
class DNSServer(Service):
|
||||
"""Represents a DNS Server as a Service."""
|
||||
|
||||
dns_table: dict[str:IPv4Address] = {}
|
||||
dns_table: Dict[str, IPv4Address] = {}
|
||||
"A dict of mappings between domain names and IPv4 addresses."
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DNSServer"
|
||||
kwargs["port"] = Port.DNS
|
||||
# DNS uses UDP by default
|
||||
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
|
||||
kwargs["protocol"] = IPProtocol.UDP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
@@ -26,15 +38,7 @@ class DNSServer(BaseModel):
|
||||
"""
|
||||
return {"Operating State": self.operating_state}
|
||||
|
||||
def apply_action(self, action: List[str]) -> None:
|
||||
"""
|
||||
Applies a list of actions to the Service.
|
||||
|
||||
:param action: A list of actions to apply. (unsure)
|
||||
"""
|
||||
pass
|
||||
|
||||
def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]:
|
||||
def dns_lookup(self, target_domain: Any) -> Optional[IPv4Address]:
|
||||
"""
|
||||
Attempts to find the IP address for a domain name.
|
||||
|
||||
@@ -42,11 +46,23 @@ class DNSServer(BaseModel):
|
||||
:return ip_address: The IP address of that domain name or None.
|
||||
"""
|
||||
if target_domain in self.dns_table:
|
||||
self.dns_table[target_domain]
|
||||
return self.dns_table[target_domain]
|
||||
else:
|
||||
return None
|
||||
|
||||
def reset_component_for_episode(self):
|
||||
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
|
||||
"""
|
||||
Register a domain name and its IP address.
|
||||
|
||||
:param: domain_name: The domain name to register
|
||||
:type: domain_name: str
|
||||
|
||||
:param: domain_ip_address: The IP address that the domain should route to
|
||||
:type: domain_ip_address: IPv4Address
|
||||
"""
|
||||
self.dns_table[domain_name] = domain_ip_address
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
@@ -54,36 +70,78 @@ class DNSServer(BaseModel):
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
self.dns_table = {}
|
||||
super().reset_component_for_episode(episode=episode)
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to send.
|
||||
:param: payload: The payload to send.
|
||||
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
|
||||
:param: dest_port: The port of the machine that the payload will be sent to
|
||||
:param: session_id: The id of the session
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
# DNS packet will be sent from DNS Server to the DNS client
|
||||
DNSPacket(
|
||||
dns_request=DNSRequest(domain_name_request=self.dns_table),
|
||||
dns_reply=DNSReply(domain_name_ip_address=payload),
|
||||
)
|
||||
try:
|
||||
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
except Exception as e:
|
||||
_LOGGER.error(e)
|
||||
return False
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def receive(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to receive. (take the domain name and do dns lookup)
|
||||
:param: payload: The payload to send.
|
||||
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
|
||||
:param: dest_port: The port of the machine that the payload will be sent to
|
||||
:param: session_id: The id of the session
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
ip_address = self.dns_lookup(payload)
|
||||
if ip_address is not None:
|
||||
self.send(ip_address, session_id)
|
||||
# The payload should be a DNS packet
|
||||
if not isinstance(payload, DNSPacket):
|
||||
_LOGGER.debug(f"{payload} is not a DNSPacket")
|
||||
return False
|
||||
# cast payload into a DNS packet
|
||||
payload: DNSPacket = payload
|
||||
if payload.dns_request is not None:
|
||||
# generate a reply with the correct DNS IP address
|
||||
payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
|
||||
# send reply
|
||||
self.send(payload, session_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""Prints a table of DNS Lookup table."""
|
||||
table = PrettyTable(["Domain Name", "IP Address"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} DNS Lookup table"
|
||||
for dns in self.dns_table.items():
|
||||
table.add_row([dns[0], dns[1]])
|
||||
print(table)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import Action, ActionManager
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.software import IOSoftware
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -72,29 +74,54 @@ class Service(IOSoftware):
|
||||
"""
|
||||
pass
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to send.
|
||||
:param: payload: The payload to send.
|
||||
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
|
||||
:param: dest_port: The port of the machine that the payload will be sent to
|
||||
:param: session_id: The id of the session
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def receive(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
|
||||
The specifics of how the payload is processed and whether a response payload
|
||||
is generated should be implemented in subclasses.
|
||||
|
||||
:param payload: The payload to receive.
|
||||
:param: payload: The payload to send.
|
||||
:param: dest_ip_address: The ip address of the machine that the payload will be sent to
|
||||
:param: dest_port: The port of the machine that the payload will be sent to
|
||||
:param: session_id: The id of the session
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import sys
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns_server import DNSServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_server() -> Node:
|
||||
node = Node(hostname="dns_server")
|
||||
node.software_manager.add_service(service_class=DNSServer)
|
||||
node.software_manager.services["DNSServer"].start()
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_client() -> Node:
|
||||
node = Node(hostname="dns_client")
|
||||
node.software_manager.add_service(service_class=DNSClient)
|
||||
node.software_manager.services["DNSClient"].start()
|
||||
return node
|
||||
|
||||
|
||||
def test_create_dns_server(dns_server):
|
||||
assert dns_server is not None
|
||||
dns_server_service: DNSServer = dns_server.software_manager.services["DNSServer"]
|
||||
assert dns_server_service.name is "DNSServer"
|
||||
assert dns_server_service.port is Port.DNS
|
||||
assert dns_server_service.protocol is IPProtocol.UDP
|
||||
|
||||
|
||||
def test_create_dns_client(dns_client):
|
||||
assert dns_client is not None
|
||||
dns_client_service: DNSClient = dns_client.software_manager.services["DNSClient"]
|
||||
assert dns_client_service.name is "DNSClient"
|
||||
assert dns_client_service.port is Port.DNS
|
||||
assert dns_client_service.protocol is IPProtocol.UDP
|
||||
|
||||
|
||||
def test_dns_server_domain_name_registration(dns_server):
|
||||
"""Test to check if the domain name registration works."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.services["DNSServer"]
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
# return none for an unknown domain
|
||||
assert dns_server_service.dns_lookup("fake-domain.com") is None
|
||||
assert dns_server_service.dns_lookup("real-domain.com") is not None
|
||||
|
||||
|
||||
def test_dns_client_check_domain_in_cache(dns_client):
|
||||
"""Test to make sure that the check_domain_in_cache returns the correct values."""
|
||||
dns_client_service: DNSClient = dns_client.software_manager.services["DNSClient"]
|
||||
|
||||
# add a domain to the dns client cache
|
||||
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
|
||||
|
||||
assert dns_client_service.check_domain_in_cache("fake-domain.com") is False
|
||||
assert dns_client_service.check_domain_in_cache("real-domain.com") is True
|
||||
Reference in New Issue
Block a user