Merged PR 177: DNS Server and Client

## Summary
Carrying on where Sunil left the task. Added more functionality to the DNS.
- DNS Server can keep a list of domain names that the DNS Client can request
- DNS Client has a cache of domain names that are valid from the DNS Server

## Test process
https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/177?_a=files&path=/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have written/updated **design docs** if this PR implements new functionality
- [X] I have update the **change log**
- [x] I have run **pre-commit** checks for code style

Related work items: #1752
This commit is contained in:
Czar Echavez
2023-09-18 10:31:16 +00:00
committed by Christopher McCarthy
17 changed files with 667 additions and 20 deletions

2
.gitignore vendored
View File

@@ -144,9 +144,11 @@ cython_debug/
# IDE
.idea/
docs/source/primaite-dependencies.rst
.vscode/
# outputs
src/primaite/outputs/
simulation_output/
# benchmark session outputs
benchmark/output

View File

@@ -29,6 +29,7 @@ SessionManager.
1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP)
- Red Agent Services:
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database)
- DNS Services: DNS Client and DNS Server
## [2.0.0] - 2023-07-26

View File

@@ -0,0 +1,56 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
DNS Client Server
=================
DNS Server
----------
Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class.
Key capabilities
^^^^^^^^^^^^^^^^
- Simulates DNS requests and DNSPacket transfer across a network
- Registers domain names and the IP Address linked to the domain name
- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future)
Implementation
^^^^^^^^^^^^^^
- DNS request and responses use a ``DNSPacket`` object
- Extends Service class for integration with ``SoftwareManager``.
DNS Client
----------
The DNSClient provides a client interface for connecting to the ``DNSServer``.
Key features
^^^^^^^^^^^^
- Connects to the ``DNSServer`` via the ``SoftwareManager``.
- Executes DNS lookup requests and keeps a cache of known domain name IP addresses.
- Handles connection to DNSServer and querying for domain name IP addresses.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future)
- Execute domain name checks with ``check_domain_exists``.
- ``DNSClient`` will automatically add the IP Address of the domain into its cache
Implementation
^^^^^^^^^^^^^^
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Provides easy interface for Nodes to find IP addresses via domain names.
- Extends base Service class.

View File

@@ -17,3 +17,4 @@ Contents
database_client_server
data_manipulation_bot
dns_client_server

View File

@@ -192,6 +192,9 @@ class SimComponent(BaseModel):
:param action: List describing the action to apply to this object.
:type action: List[str]
:param: context: Dict containing context for actions
:type context: Dict
"""
if self.action_manager is None:
return

View File

@@ -5,7 +5,7 @@ import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, Literal, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
@@ -89,8 +89,6 @@ class NIC(SimComponent):
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
wake_on_lan: bool = False
"Indicates if the NIC supports Wake-on-LAN functionality."
dns_servers: List[IPv4Address] = []
"List of IP addresses of DNS servers used for name resolution."
connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
connected_link: Optional[Link] = None
@@ -406,7 +404,8 @@ class SwitchPort(SimComponent):
if self.enabled:
frame.decrement_ttl()
self.pcap.capture(frame)
self.connected_node.forward_frame(frame=frame, incoming_port=self)
connected_node: Node = self.connected_node
connected_node.forward_frame(frame=frame, incoming_port=self)
return True
return False
@@ -881,6 +880,8 @@ class Node(SimComponent):
"The NICs on the node."
ethernet_port: Dict[int, NIC] = {}
"The NICs on the node by port id."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
accounts: Dict[str, Account] = {}
"All accounts on the node."
@@ -930,6 +931,7 @@ class Node(SimComponent):
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
dns_server=kwargs.get("dns_server"),
)
super().__init__(**kwargs)
self.arp.nics = self.nics

View File

@@ -10,6 +10,8 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
@@ -126,9 +128,16 @@ def arcd_uc2_network() -> Network:
# Client 1
client_1 = Computer(
hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
)
client_1.power_on()
client_1.software_manager.install(DNSClient)
client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa
client_1_dns_client_service.start()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
@@ -136,9 +145,16 @@ def arcd_uc2_network() -> Network:
# Client 2
client_2 = Computer(
hostname="client_2", ip_address="192.168.10.22", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
)
client_2.power_on()
client_2.software_manager.install(DNSClient)
client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa
client_2_dns_client_service.start()
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller
@@ -149,6 +165,8 @@ def arcd_uc2_network() -> Network:
default_gateway="192.168.1.1",
)
domain_controller.power_on()
domain_controller.software_manager.install(DNSServer)
network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
# Database Server
@@ -157,6 +175,7 @@ def arcd_uc2_network() -> Network:
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
database_server.power_on()
network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])
@@ -196,19 +215,33 @@ def arcd_uc2_network() -> Network:
# Web Server
web_server = Server(
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
hostname="web_server",
ip_address="192.168.1.12",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
database_client.run()
database_client.connect()
# register the web_server to a domain
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa
dns_server_service.start()
dns_server_service.dns_register("arcd.com", web_server.ip_address)
# Backup Server
backup_server = Server(
hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
hostname="backup_server",
ip_address="192.168.1.16",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
backup_server.power_on()
network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])
@@ -219,6 +252,7 @@ def arcd_uc2_network() -> Network:
ip_address="192.168.1.110",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
security_suite.power_on()
network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])
@@ -229,6 +263,12 @@ def arcd_uc2_network() -> Network:
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
# Allow PostgreSQL requests
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
# Allow DNS requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
return network

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Optional
from pydantic import BaseModel
class DNSRequest(BaseModel):
"""Represents a DNS Request packet of a network frame.
:param domain_name_request: Domain Name Request for IP address.
"""
domain_name_request: str
"Domain Name Request for IP address."
class DNSReply(BaseModel):
"""Represents a DNS Reply packet of a network frame.
:param domain_name_ip_address: IP Address of the Domain Name requested.
"""
domain_name_ip_address: Optional[IPv4Address] = None
"IP Address of the Domain Name requested."
class DNSPacket(BaseModel):
"""
Represents the DNS layer of a network frame.
:param dns_request: DNS Request packet sent by DNS Client.
:param dns_reply: DNS Reply packet generated by DNS Server.
:Example:
>>> dns_request = DNSPacket(
... domain_name_request=DNSRequest(domain_name_request="www.google.co.uk"),
... dns_reply=None
... )
>>> dns_response = DNSPacket(
... dns_request=DNSRequest(domain_name_request="www.google.co.uk"),
... dns_reply=DNSReply(domain_name_ip_address=IPv4Address("142.250.179.227"))
... )
"""
dns_request: DNSRequest
"DNS Request packet sent by DNS Client."
dns_reply: Optional[DNSReply] = None
"DNS Reply packet generated by DNS Server."
def generate_reply(self, domain_ip_address: IPv4Address) -> DNSPacket:
"""Generate a new DNSPacket to be sent as a response with a DNS Reply packet which contains the IP address.
:param domain_ip_address: The IP address that was being sought after from the original target domain name.
:return: A new instance of DNSPacket.
"""
self.dns_reply = DNSReply(domain_name_ip_address=domain_ip_address)
return self

View File

@@ -0,0 +1,54 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite.simulator.system.applications.application import Application
class WebBrowser(Application):
"""
Represents a web browser in the simulation environment.
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
domain_name: str
"The domain name of the webpage."
domain_name_ip_address: Optional[IPv4Address]
"The IP address of the domain name for the webpage."
history: Dict[str]
"A dict that stores all of the previous domain names."
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
self.domain_name = ""
self.domain_name_ip_address = None
self.history = {}
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:return: True if successful, False otherwise.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:return: True if successful, False otherwise.
"""
pass

View File

@@ -211,7 +211,7 @@ class SessionManager:
session = Session.from_session_key(session_key)
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
self.software_manager.receive_payload_from_session_manger(
self.software_manager.receive_payload_from_session_manager(
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
)

View File

@@ -23,7 +23,13 @@ IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
class SoftwareManager:
"""A class that manages all running Services and Applications on a Node and facilitates their communication."""
def __init__(self, session_manager: "SessionManager", sys_log: SysLog, file_system: FileSystem):
def __init__(
self,
session_manager: "SessionManager",
sys_log: SysLog,
file_system: FileSystem,
dns_server: Optional[IPv4Address],
):
"""
Initialize a new instance of SoftwareManager.
@@ -35,6 +41,7 @@ class SoftwareManager:
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
self.dns_server: Optional[IPv4Address] = dns_server
def get_open_ports(self) -> List[Port]:
"""
@@ -58,7 +65,9 @@ class SoftwareManager:
if software_class in self._software_class_to_name_map:
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
return
software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
software = software_class(
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
)
if isinstance(software, Application):
software.install()
software.software_manager = self
@@ -114,7 +123,7 @@ class SoftwareManager:
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
)
def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
"""
Receive a payload from the SessionManager and forward it to the corresponding service or application.

View File

@@ -0,0 +1,154 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class DNSClient(Service):
"""Represents a DNS Client as a Service."""
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."
dns_server: Optional[IPv4Address] = None
"The DNS Server the client sends requests to."
def __init__(self, **kwargs):
kwargs["name"] = "DNSClient"
kwargs["port"] = Port.DNS
# DNS uses UDP by default
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
# TCP for now
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current state of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
"""
Adds a domain name to the DNS Client cache.
:param: domain_name: The domain name to save to cache
:param: ip_address: The IP Address to attach the domain name to
"""
self.dns_cache[domain_name] = ip_address
def check_domain_exists(
self,
target_domain: str,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> bool:
"""Function to check if domain name exists.
:param: target_domain: The domain requested for an IP address.
:param: session_id: The Session ID the payload is to originate from. Optional.
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
"""
# check if the target domain is in the client's DNS cache
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
# check if the domain is already in the DNS cache
if target_domain in self.dns_cache:
self.sys_log.info(
f"DNS Client: Domain lookup for {target_domain} successful, resolves to {self.dns_cache[target_domain]}"
)
return True
else:
# return False if already reattempted
if is_reattempt:
self.sys_log.info(f"DNS Client: Domain lookup for {target_domain} failed")
return False
else:
# send a request to check if domain name exists in the DNS Server
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS
)
# recursively re-call the function passing is_reattempt=True
return self.check_domain_exists(
target_domain=target_domain,
session_id=session_id,
is_reattempt=True,
)
def send(
self,
payload: DNSPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
# create DNS request packet
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return True
def receive(
self,
payload: DNSPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
# The payload should be a DNS packet
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_reply is not None:
# add the IP address to the client cache
if payload.dns_reply.domain_name_ip_address:
self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address
return True
return False

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class DNSServer(Service):
"""Represents a DNS Server as a Service."""
dns_table: Dict[str, IPv4Address] = {}
"A dict of mappings between domain names and IPv4 addresses."
def __init__(self, **kwargs):
kwargs["name"] = "DNSServer"
kwargs["port"] = Port.DNS
# DNS uses UDP by default
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
# TCP for now
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current state of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]:
"""
Attempts to find the IP address for a domain name.
:param target_domain: The single domain name requested by a DNS client.
:return ip_address: The IP address of that domain name or None.
"""
return self.dns_table.get(target_domain)
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
"""
Register a domain name and its IP address.
:param: domain_name: The domain name to register
:type: domain_name: str
:param: domain_ip_address: The IP address that the domain should route to
:type: domain_ip_address: IPv4Address
"""
self.dns_table[domain_name] = domain_ip_address
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param: payload: The payload to send.
:param: session_id: The id of the session. Optional.
:return: True if DNS request returns a valid IP, otherwise, False
"""
# The payload should be a DNS packet
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_request is not None:
self.sys_log.info(
f"DNS Server: Received domain lookup request for {payload.dns_request.domain_name_request} "
f"from session {session_id}"
)
# generate a reply with the correct DNS IP address
payload = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
self.sys_log.info(
f"DNS Server: Responding to domain lookup request for {payload.dns_request.domain_name_request} "
f"with ip address: {payload.dns_reply.domain_name_ip_address}"
)
# send reply
self.send(payload, session_id)
return payload.dns_reply.domain_name_ip_address is not None
return False
def show(self, markdown: bool = False):
"""Prints a table of DNS Lookup table."""
table = PrettyTable(["Domain Name", "IP Address"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} DNS Lookup table"
for dns in self.dns_table.items():
table.add_row([dns[0], dns[1]])
print(table)

View File

@@ -72,29 +72,44 @@ class Service(IOSoftware):
"""
pass
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
def send(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:param: payload: The payload to send.
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
pass
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
def receive(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:param: payload: The payload to send.
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
pass
pass
def stop(self) -> None:
"""Stop the service."""

View File

@@ -0,0 +1,28 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_dns_client_server(uc2_network):
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
assert dns_client.operating_state == ServiceOperatingState.RUNNING
assert dns_server.operating_state == ServiceOperatingState.RUNNING
dns_server.show()
# fake domain should not be added to dns cache
assert not dns_client.check_domain_exists(target_domain="fake-domain.com")
assert dns_client.dns_cache.get("fake-domain.com", None) is None
# arcd.com is registered in dns server and should be saved to cache
assert dns_client.check_domain_exists(target_domain="arcd.com")
assert dns_client.dns_cache.get("arcd.com", None) is not None

View File

@@ -45,7 +45,6 @@ def test_seeded_learning(temp_primaite_session):
), "Expected output is based upon a agent that was trained with seed 67890"
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
print(actual_mean_reward_per_episode, "THISt")
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode

View File

@@ -0,0 +1,100 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Node(hostname="dns_server")
node.software_manager.install(software_class=DNSServer)
node.software_manager.software["DNSServer"].start()
return node
@pytest.fixture(scope="function")
def dns_client() -> Node:
node = Node(hostname="dns_client")
node.software_manager.install(software_class=DNSClient)
node.software_manager.software["DNSClient"].start()
return node
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.TCP
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.TCP
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
# return none for an unknown domain
assert dns_server_service.dns_lookup("fake-domain.com") is None
assert dns_server_service.dns_lookup("real-domain.com") is not None
def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
# add a domain to the dns client cache
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
assert dns_client_service.check_domain_exists("fake-domain.com") is False
assert dns_client_service.check_domain_exists("real-domain.com") is True
def test_dns_server_receive(dns_server):
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
is False
)
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
is True
)
dns_server_service.show()
def test_dns_client_receive(dns_client):
"""Test to make sure the DNS Client knows how to deal with request responses."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service.receive(
payload=DNSPacket(
dns_request=DNSRequest(domain_name_request="real-domain.com"),
dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")),
)
)
# domain name should be saved to cache
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")