From b1d8666c163798adaf95c05f4b65c1028bbe4289 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 8 Sep 2023 16:50:49 +0100 Subject: [PATCH] #1816 - Added database client. Installed the database client on the Web Server node in the UC2 network. Updated the integration test to query the DB server using the DB client. --- .../network/internal_frame_processing.rst | 0 .../simulator/file_system/file_type.py | 170 ++++++++++++++++++ .../simulator/network/hardware/base.py | 102 +++++++---- src/primaite/simulator/network/networks.py | 14 +- .../system/applications/application.py | 23 ++- .../system/applications/database_client.py | 83 +++++++++ .../simulator/system/core/software_manager.py | 89 ++++----- .../simulator/system/services/database.py | 58 +++++- .../simulator/system/services/service.py | 7 - tests/conftest.py | 7 + .../system/test_database_on_node.py | 57 +++--- 11 files changed, 478 insertions(+), 132 deletions(-) create mode 100644 docs/source/simulation_components/network/internal_frame_processing.rst create mode 100644 src/primaite/simulator/file_system/file_type.py create mode 100644 src/primaite/simulator/system/applications/database_client.py diff --git a/docs/source/simulation_components/network/internal_frame_processing.rst b/docs/source/simulation_components/network/internal_frame_processing.rst new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/simulator/file_system/file_type.py b/src/primaite/simulator/file_system/file_type.py new file mode 100644 index 00000000..140cd0e7 --- /dev/null +++ b/src/primaite/simulator/file_system/file_type.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from enum import Enum +from random import choice + + +class FileType(Enum): + """An enumeration of common file types.""" + + UNKNOWN = 0 + "Unknown file type." + + # Text formats + TXT = 1 + "Plain text file." + DOC = 2 + "Microsoft Word document (.doc)" + DOCX = 3 + "Microsoft Word document (.docx)" + PDF = 4 + "Portable Document Format." + HTML = 5 + "HyperText Markup Language file." + XML = 6 + "Extensible Markup Language file." + CSV = 7 + "Comma-Separated Values file." + + # Spreadsheet formats + XLS = 8 + "Microsoft Excel file (.xls)" + XLSX = 9 + "Microsoft Excel file (.xlsx)" + + # Image formats + JPEG = 10 + "JPEG image file." + PNG = 11 + "PNG image file." + GIF = 12 + "GIF image file." + BMP = 13 + "Bitmap image file." + + # Audio formats + MP3 = 14 + "MP3 audio file." + WAV = 15 + "WAV audio file." + + # Video formats + MP4 = 16 + "MP4 video file." + AVI = 17 + "AVI video file." + MKV = 18 + "MKV video file." + FLV = 19 + "FLV video file." + + # Presentation formats + PPT = 20 + "Microsoft PowerPoint file (.ppt)" + PPTX = 21 + "Microsoft PowerPoint file (.pptx)" + + # Web formats + JS = 22 + "JavaScript file." + CSS = 23 + "Cascading Style Sheets file." + + # Programming languages + PY = 24 + "Python script file." + C = 25 + "C source code file." + CPP = 26 + "C++ source code file." + JAVA = 27 + "Java source code file." + + # Compressed file types + RAR = 28 + "RAR archive file." + ZIP = 29 + "ZIP archive file." + TAR = 30 + "TAR archive file." + GZ = 31 + "Gzip compressed file." + + # Database file types + DB = 32 + "Generic DB file. Used by sqlite3." + + @classmethod + def _missing_(cls, value): + return cls.UNKNOWN + + @classmethod + def random(cls) -> FileType: + """ + Returns a random FileType. + + :return: A random FileType. + """ + return choice(list(FileType)) + + @property + def default_size(self) -> int: + """ + Get the default size of the FileType in bytes. + + Returns 0 if a default size does not exist. + """ + size = file_type_sizes_bytes[self] + return size if size else 0 + + +def get_file_type_from_extension(file_type_extension: str): + """ + Get a FileType from a file type extension. + + If a matching extension does not exist, FileType.UNKNOWN is returned. + + :param file_type_extension: A file type extension. + :return: A file type extension. + """ + try: + return FileType[file_type_extension.upper()] + except KeyError: + return FileType.UNKNOWN + + +file_type_sizes_bytes = { + FileType.UNKNOWN: 0, + FileType.TXT: 4096, + FileType.DOC: 51200, + FileType.DOCX: 30720, + FileType.PDF: 102400, + FileType.HTML: 15360, + FileType.XML: 10240, + FileType.CSV: 15360, + FileType.XLS: 102400, + FileType.XLSX: 25600, + FileType.JPEG: 102400, + FileType.PNG: 40960, + FileType.GIF: 30720, + FileType.BMP: 307200, + FileType.MP3: 5120000, + FileType.WAV: 25600000, + FileType.MP4: 25600000, + FileType.AVI: 51200000, + FileType.MKV: 51200000, + FileType.FLV: 15360000, + FileType.PPT: 204800, + FileType.PPTX: 102400, + FileType.JS: 10240, + FileType.CSS: 5120, + FileType.PY: 5120, + FileType.C: 5120, + FileType.CPP: 10240, + FileType.JAVA: 10240, + FileType.RAR: 1024000, + FileType.ZIP: 1024000, + FileType.TAR: 1024000, + FileType.GZ: 819200, + FileType.DB: 15360000, +} diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index fa1058cf..efc1e251 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -5,7 +5,7 @@ import secrets from enum import Enum from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable @@ -959,7 +959,24 @@ class Node(SimComponent): ) return state - def show(self, markdown: bool = False): + def show(self, markdown: bool = False, component: Literal["NIC", "OPEN_PORTS"] = "NIC"): + if component == "NIC": + self._show_nic(markdown) + elif component == "OPEN_PORTS": + self._show_open_ports(markdown) + + def _show_open_ports(self, markdown: bool = False): + """Prints a table of the open ports on the Node.""" + table = PrettyTable(["Port", "Name"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.hostname} Open Ports" + for port in self.software_manager.get_open_ports(): + table.add_row([port.value, port.name]) + print(table) + + def _show_nic(self, markdown: bool = False): """Prints a table of the NICs on the Node.""" table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"]) if markdown: @@ -1048,29 +1065,30 @@ class Node(SimComponent): :param pings: The number of pings to attempt, default is 4. :return: True if the ping is successful, otherwise False. """ - if not isinstance(target_ip_address, IPv4Address): - target_ip_address = IPv4Address(target_ip_address) - if target_ip_address.is_loopback: - self.sys_log.info("Pinging loopback address") - return any(nic.enabled for nic in self.nics.values()) if self.operating_state == NodeOperatingState.ON: - self.sys_log.info(f"Pinging {target_ip_address}:") - sequence, identifier = 0, None - while sequence < pings: - sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings) - request_replies = self.icmp.request_replies.get(identifier) - passed = request_replies == pings - if request_replies: - self.icmp.request_replies.pop(identifier) - else: - request_replies = 0 - self.sys_log.info( - f"Ping statistics for {target_ip_address}: " - f"Packets: Sent = {pings}, " - f"Received = {request_replies}, " - f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)" - ) - return passed + if not isinstance(target_ip_address, IPv4Address): + target_ip_address = IPv4Address(target_ip_address) + if target_ip_address.is_loopback: + self.sys_log.info("Pinging loopback address") + return any(nic.enabled for nic in self.nics.values()) + if self.operating_state == NodeOperatingState.ON: + self.sys_log.info(f"Pinging {target_ip_address}:") + sequence, identifier = 0, None + while sequence < pings: + sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings) + request_replies = self.icmp.request_replies.get(identifier) + passed = request_replies == pings + if request_replies: + self.icmp.request_replies.pop(identifier) + else: + request_replies = 0 + self.sys_log.info( + f"Ping statistics for {target_ip_address}: " + f"Packets: Sent = {pings}, " + f"Received = {request_replies}, " + f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)" + ) + return passed return False def send_frame(self, frame: Frame): @@ -1079,7 +1097,8 @@ class Node(SimComponent): :param frame: The Frame to be sent. """ - nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address) + if self.operating_state == NodeOperatingState.ON: + nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address) nic.send_frame(frame) def receive_frame(self, frame: Frame, from_nic: NIC): @@ -1092,20 +1111,27 @@ class Node(SimComponent): :param frame: The Frame being received. :param from_nic: The NIC that received the frame. """ - if frame.ip: - if frame.ip.src_ip_address in self.arp: - self.arp.add_arp_cache_entry( - ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic - ) - if frame.ip.protocol == IPProtocol.TCP: - if frame.tcp.src_port == Port.ARP: - self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) + if self.operating_state == NodeOperatingState.ON: + if frame.ip: + if frame.ip.src_ip_address in self.arp: + self.arp.add_arp_cache_entry( + ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic + ) + if frame.ip.protocol == IPProtocol.ICMP: + self.icmp.process_icmp(frame=frame, from_nic=from_nic) + return + # Check if the destination port is open on the Node + if frame.tcp.dst_port in self.software_manager.get_open_ports(): + # accept thr frame as the port is open + if frame.tcp.src_port == Port.ARP: + self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) + else: + self.session_manager.receive_frame(frame) else: - self.session_manager.receive_frame(frame) - elif frame.ip.protocol == IPProtocol.UDP: - pass - elif frame.ip.protocol == IPProtocol.ICMP: - self.icmp.process_icmp(frame=frame, from_nic=from_nic) + # denied as port closed + self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}") + # TODO: do we need to do anything more here? + pass def install_service(self, service: Service) -> None: """ diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index c030d907..a364abea 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database import DatabaseService @@ -149,6 +150,9 @@ def arcd_uc2_network() -> Network: hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" ) web_server.power_on() + web_server.software_manager.install(DatabaseClient) + database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + database_client.run() network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) # Database Server @@ -187,12 +191,12 @@ def arcd_uc2_network() -> Network: "INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa "INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa ] - database_server.software_manager.add_service(DatabaseService) - database: DatabaseService = database_server.software_manager.services["Database"] # noqa - database.start() - database._process_sql(ddl) # noqa + database_server.software_manager.install(DatabaseService) + database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa + database_service.start() + database_service._process_sql(ddl) # noqa for insert_statement in user_insert_statements: - database._process_sql(insert_statement) # noqa + database_service._process_sql(insert_statement) # noqa # Backup Server backup_server = Server( diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 6a07f00f..2a3013e1 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -23,9 +23,9 @@ class Application(IOSoftware): Applications are user-facing programs that may perform input/output operations. """ - operating_state: ApplicationOperatingState + operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED "The current operating state of the Application." - execution_control_status: str + execution_control_status: str = "manual" "Control status of the application's execution. It could be 'manual' or 'automatic'." num_executions: int = 0 "The number of times the application has been executed. Default is 0." @@ -53,6 +53,25 @@ class Application(IOSoftware): ) return state + def run(self) -> None: + """Open the Application""" + if self.operating_state == ApplicationOperatingState.CLOSED: + self.sys_log.info(f"Running Application {self.name}") + self.operating_state = ApplicationOperatingState.RUNNING + + def close(self) -> None: + """Close the Application""" + if self.operating_state == ApplicationOperatingState.RUNNING: + self.sys_log.info(f"Closed Application{self.name}") + self.operating_state = ApplicationOperatingState.CLOSED + + def install(self) -> None: + """Install Application.""" + super().install() + if self.operating_state == ApplicationOperatingState.CLOSED: + self.sys_log.info(f"Installing Application {self.name}") + self.operating_state = ApplicationOperatingState.INSTALLING + def reset_component_for_episode(self, episode: int): """ Resets the Application component for a new episode. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py new file mode 100644 index 00000000..38ce3c7f --- /dev/null +++ b/src/primaite/simulator/system/applications/database_client.py @@ -0,0 +1,83 @@ +from ipaddress import IPv4Address +from typing import Any, Dict, Optional + +from prettytable import PrettyTable + +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.core.software_manager import SoftwareManager + + +class DatabaseClient(Application): + server_ip_address: Optional[IPv4Address] = None + connected: bool = False + + def __init__(self, **kwargs): + kwargs["name"] = "DatabaseClient" + kwargs["port"] = Port.POSTGRES_SERVER + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + return super().describe_state() + + def connect(self, server_ip_address: IPv4Address, password: Optional[str] = None) -> bool: + if not self.connected and self.operating_state.RUNNING: + return self._connect(server_ip_address, password) + + def _connect( + self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False + ) -> bool: + if is_reattempt: + if self.connected: + self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised") + self.server_ip_address = server_ip_address + return self.connected + else: + self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined") + payload = {"type": "connect_request", "password": password} + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=server_ip_address, dest_port=self.port + ) + return self._connect(server_ip_address, password, True) + + def disconnect(self): + if self.connected and self.operating_state.RUNNING: + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port + ) + + self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}") + self.server_ip_address = None + + def query(self, sql: str): + if self.connected and self.operating_state.RUNNING: + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "sql", "sql": sql}, dest_ip_address=self.server_ip_address, dest_port=self.port + ) + + def _print_data(self, data: Dict): + """ + Display the contents of the Folder in tabular format. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ + table = PrettyTable(list(data.values())[0]) + + table.align = "l" + table.title = f"{self.sys_log.hostname} Database Client" + for row in data.values(): + table.add_row(row.values()) + print(table) + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "connect_response": + self.connected = payload["response"] == True + elif payload["type"] == "sql": + self._print_data(payload["data"]) + return True diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 13d4524c..c3fe29fd 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -1,15 +1,15 @@ from ipaddress import IPv4Address -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.services.service import Service -from primaite.simulator.system.software import SoftwareType +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import IOSoftware, SoftwareType if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager @@ -17,7 +17,7 @@ if TYPE_CHECKING: from typing import Type, TypeVar -ServiceClass = TypeVar("ServiceClass", bound=Service) +IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) class SoftwareManager: @@ -30,57 +30,55 @@ class SoftwareManager: :param session_manager: The session manager handling network communications. """ self.session_manager = session_manager - self.services: Dict[str, Service] = {} - self.applications: Dict[str, Application] = {} + self.software: Dict[str, Union[Service, Application]] = {} + self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {} self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log self.file_system: FileSystem = file_system - def add_service(self, service_class: Type[ServiceClass]): - """ - Add a Service to the manager. + def get_open_ports(self) -> List[Port]: + open_ports = [Port.ARP] + for software in self.port_protocol_mapping.values(): + if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}: + open_ports.append(software.port) + open_ports.sort(key=lambda port: port.value) + return open_ports - :param: service_class: The class of the service to add - """ - service = service_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system) + def install(self, software_class: Type[IOSoftwareClass]): + if software_class in self._software_class_to_name_map: + self.sys_log.info(f"Cannot install {software_class} as it is already installed") + return + software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system) + if isinstance(software, Application): + software.install() + software.software_manager = self + self.software[software.name] = software + self.port_protocol_mapping[(software.port, software.protocol)] = software + self.sys_log.info(f"Installed {software.name}") + if isinstance(software, Application): + software.operating_state = ApplicationOperatingState.CLOSED - service.software_manager = self - self.services[service.name] = service - self.port_protocol_mapping[(service.port, service.protocol)] = service + def uninstall(self, software_name: str): + if software_name in self.software: + software = self.software.pop(software_name) # noqa + del software + self.sys_log.info(f"Deleted {software_name}") + return + self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") - def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol): - """ - Add an Application to the manager. - - :param name: The name of the application. - :param application: The application instance. - :param port: The port used by the application. - :param protocol: The network protocol used by the application. - """ - application.software_manager = self - self.applications[name] = application - self.port_protocol_mapping[(port, protocol)] = application - - def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any): + def send_internal_payload(self, target_software: str, payload: Any): """ Send a payload to a specific service or application. :param target_software: The name of the target service or application. - :param target_software_type: The type of software (Service, Application, Process). :param payload: The data to be sent. - :param receiver_type: The type of the target, either 'service' or 'application'. """ - if target_software_type is SoftwareType.SERVICE: - receiver = self.services.get(target_software) - elif target_software_type is SoftwareType.APPLICATION: - receiver = self.applications.get(target_software) - else: - raise ValueError(f"Invalid receiver type {target_software_type}") + receiver = self.software.get(target_software) if receiver: receiver.receive_payload(payload) else: - raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}") + self.sys_log.error(f"No Service of Application found with the name {target_software}") def send_payload_to_session_manager( self, @@ -121,13 +119,20 @@ class SoftwareManager: :param markdown: If True, outputs the table in markdown format. Default is False. """ - table = PrettyTable(["Name", "Operating State", "Health State", "Port"]) + table = PrettyTable(["Name", "Type", "Operating State", "Health State", "Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.sys_log.hostname} Software Manager" - for service in self.services.values(): + for software in self.port_protocol_mapping.values(): + software_type = "Service" if isinstance(software, Service) else "Application" table.add_row( - [service.name, service.operating_state.name, service.health_state_actual.name, service.port.value] + [ + software.name, + software_type, + software.operating_state.name, + software.health_state_actual.name, + software.port.value, + ] ) print(table) diff --git a/src/primaite/simulator/system/services/database.py b/src/primaite/simulator/system/services/database.py index dc148031..e34f06fa 100644 --- a/src/primaite/simulator/system/services/database.py +++ b/src/primaite/simulator/system/services/database.py @@ -1,4 +1,5 @@ import sqlite3 +from datetime import datetime from ipaddress import IPv4Address from sqlite3 import OperationalError from typing import Any, Dict, List, Optional, Union @@ -8,8 +9,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.file_system.file_system import File from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.software_manager import SoftwareManager -from primaite.simulator.system.services.service import Service +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState class DatabaseService(Service): @@ -19,11 +22,11 @@ class DatabaseService(Service): This class inherits from the `Service` class and provides methods to manage and query a SQLite database. """ - backup_server: Optional[IPv4Address] = None - "The IP Address of the server the " + password: Optional[str] = None + connections: Dict[str, datetime] = {} def __init__(self, **kwargs): - kwargs["name"] = "Database" + kwargs["name"] = "DatabaseService" kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -62,6 +65,24 @@ class DatabaseService(Service): self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True) self.folder = self._db_file.folder + def _process_connect( + self, session_id: str, password: Optional[str] = None + ) -> Dict[str, Union[int, Dict[str, bool]]]: + status_code = 500 # Default internal server error + if self.operating_state == ServiceOperatingState.RUNNING: + status_code = 503 # service unavailable + if self.health_state_actual == SoftwareHealthState.GOOD: + if self.password == password: + status_code = 200 # ok + self.connections[session_id] = datetime.now() + self.sys_log.info(f"Connect request for {session_id=} authorised") + else: + status_code = 401 # Unauthorised + self.sys_log.info(f"Connect request for {session_id=} declined") + else: + status_code = 404 # service not found + return {"status_code": status_code, "type": "connect_response", "response": status_code == 200} + def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]: """ Executes the given SQL query and returns the result. @@ -71,12 +92,21 @@ class DatabaseService(Service): """ try: self._cursor.execute(query) + self._conn.commit() except OperationalError: # Handle the case where the table does not exist. return {"status_code": 404, "data": []} - - return {"status_code": 200, "data": self._cursor.fetchall()} + data = [] + description = self._cursor.description + if description: + headers = [] + for header in description: + headers.append(header[0]) + data = self._cursor.fetchall() + if data and headers: + data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data} + return {"status_code": 200, "type": "sql", "data": data} def describe_state(self) -> Dict: """ @@ -97,10 +127,20 @@ class DatabaseService(Service): :param session_id: The session identifier. :return: True if the Status Code is 200, otherwise False. """ - result = self._process_sql(payload) + result = {"status_code": 500, "data": []} + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "connect_request": + result = self._process_connect(session_id=session_id, password=payload.get("password")) + elif payload["type"] == "disconnect": + if session_id in self.connections: + self.connections.pop(session_id) + elif payload["type"] == "sql": + if session_id in self.connections: + result = self._process_sql(payload.get("sql")) + else: + result = {"status_code": 401, "type": "sql"} self.send(payload=result, session_id=session_id) - - return payload["status_code"] == 200 + return True def send(self, payload: Any, session_id: str, **kwargs) -> bool: """ diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index b9340103..30b48527 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -98,35 +98,30 @@ class Service(IOSoftware): def stop(self) -> None: """Stop the service.""" - _LOGGER.debug(f"Stopping service {self.name}") if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Stopping service {self.name}") self.operating_state = ServiceOperatingState.STOPPED def start(self, **kwargs) -> None: """Start the service.""" - _LOGGER.debug(f"Starting service {self.name}") if self.operating_state == ServiceOperatingState.STOPPED: self.sys_log.info(f"Starting service {self.name}") self.operating_state = ServiceOperatingState.RUNNING def pause(self) -> None: """Pause the service.""" - _LOGGER.debug(f"Pausing service {self.name}") if self.operating_state == ServiceOperatingState.RUNNING: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.PAUSED def resume(self) -> None: """Resume paused service.""" - _LOGGER.debug(f"Resuming service {self.name}") if self.operating_state == ServiceOperatingState.PAUSED: self.sys_log.info(f"Resuming service {self.name}") self.operating_state = ServiceOperatingState.RUNNING def restart(self) -> None: """Restart running service.""" - _LOGGER.debug(f"Restarting service {self.name}") if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.RESTARTING @@ -134,13 +129,11 @@ class Service(IOSoftware): def disable(self) -> None: """Disable the service.""" - _LOGGER.debug(f"Disabling service {self.name}") self.sys_log.info(f"Disabling Application {self.name}") self.operating_state = ServiceOperatingState.DISABLED def enable(self) -> None: """Enable the disabled service.""" - _LOGGER.debug(f"Enabling service {self.name}") if self.operating_state == ServiceOperatingState.DISABLED: self.sys_log.info(f"Enabling Application {self.name}") self.operating_state = ServiceOperatingState.STOPPED diff --git a/tests/conftest.py b/tests/conftest.py index 9c216a8e..35548f2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,8 @@ import pytest from primaite import getLogger from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession +from primaite.simulator.network.container import Network +from primaite.simulator.network.networks import arcd_uc2_network from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -24,6 +26,11 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.base import Node +@pytest.fixture(scope="function") +def uc2_network() -> Network: + return arcd_uc2_network() + + @pytest.fixture(scope="function") def file_system() -> FileSystem: return Node(hostname="fs_node").file_system diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 7ad11222..7562b29b 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -1,39 +1,38 @@ from ipaddress import IPv4Address -from primaite.simulator.network.hardware.nodes.computer import Computer -from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database import DatabaseService -def test_database_query_across_the_network(): +def test_database_client_server_connection(uc2_network): + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + assert len(db_service.connections) == 0 + + assert db_client.connect(server_ip_address=IPv4Address("192.168.1.14")) + assert len(db_service.connections) == 1 + + db_client.disconnect() + assert len(db_service.connections) == 0 + + +def test_database_client_query(uc2_network): """Tests DB query across the network returns HTTP status 200 and date.""" - network = arcd_uc2_network() + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] - client_1: Computer = network.get_node_by_hostname("client_1") + db_client.connect(server_ip_address=IPv4Address("192.168.1.14")) - client_1.arp.send_arp_request(IPv4Address("192.168.1.14")) + db_client.query("SELECT * FROM user;") - dst_mac_address = client_1.arp.get_arp_cache_mac_address(IPv4Address("192.168.1.14")) + web_server_nic = web_server.ethernet_port[1] - outbound_nic = client_1.arp.get_arp_cache_nic(IPv4Address("192.168.1.14")) - client_1.ping("192.168.1.14") + web_server_last_payload = web_server_nic.pcap.read()[-1]["payload"] - frame = Frame( - ethernet=EthernetHeader(src_mac_addr=client_1.ethernet_port[1].mac_address, dst_mac_addr=dst_mac_address), - ip=IPPacket( - src_ip_address=client_1.ethernet_port[1].ip_address, - dst_ip_address=IPv4Address("192.168.1.14"), - precedence=Precedence.FLASH, - ), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), - payload="SELECT * FROM user;", - ) - - outbound_nic.send_frame(frame) - - client_1_last_payload = outbound_nic.pcap.read()[-1]["payload"] - - assert client_1_last_payload["status_code"] == 200 - assert client_1_last_payload["data"] + assert web_server_last_payload["status_code"] == 200 + assert web_server_last_payload["data"]