From 6b41bec32a2dcb279cd8b2bf02824dafca4d8d9a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 6 Sep 2023 22:01:51 +0100 Subject: [PATCH] =?UTF-8?q?#1816=20-=20Added=20the=20final=20pieces=20of?= =?UTF-8?q?=20the=20puzzle=20to=20get=20data=20up=20from=20NIC=20=E2=86=92?= =?UTF-8?q?=20session=20manager=20=E2=86=92=20software=20manager=20?= =?UTF-8?q?=E2=86=92=20service.=20-=20Implemented=20a=20basic=20sim=20DB?= =?UTF-8?q?=20that=20matches=20UC2=20data=20manipulation=20DB=20in=20IY.?= =?UTF-8?q?=20-=20Added=20a=20test=20that=20confirms=20DB=20queries=20can?= =?UTF-8?q?=20be=20sent=20over=20the=20network.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../simulator/file_system/file_system.py | 24 ++-- .../simulator/network/hardware/base.py | 13 ++- src/primaite/simulator/network/networks.py | 37 ++++++ .../simulator/system/core/packet_capture.py | 8 ++ .../simulator/system/core/session_manager.py | 88 +++++++------- .../simulator/system/core/software_manager.py | 24 ++-- .../simulator/system/services/database.py | 107 +++++++++--------- src/primaite/simulator/system/software.py | 8 +- tests/conftest.py | 10 ++ .../system/test_database_on_node.py | 86 +++++++------- .../_file_system/test_file_system.py | 25 ++-- .../_file_system/test_file_system_file.py | 23 ---- .../_file_system/test_file_system_folder.py | 75 ------------ .../test_data_manipulator_service.py | 2 +- .../_system/_services/test_database.py | 68 +++++++++-- 15 files changed, 300 insertions(+), 298 deletions(-) delete mode 100644 tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py delete mode 100644 tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index a5744b4b..d5e81e1b 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -2,6 +2,7 @@ from __future__ import annotations import math import os.path +import shutil from abc import abstractmethod from pathlib import Path from typing import Dict, Optional @@ -220,22 +221,14 @@ class FileSystem(SimComponent): dst_folder = self.create_folder(dst_folder_name) # add file to dst dst_folder.add_file(file) + if file.real: + old_sim_path = file.sim_path + file.sim_path = file.folder.fs.sim_root / file.path + file.sim_path.parent.mkdir(exist_ok=True) + shutil.move(old_sim_path, file.sim_path) def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name): - """ - Copies a file from one folder to another. - can provide - - :param file: The file to move - :type: file: File - - :param src_folder: The folder where the file is located - :type: Folder - - :param target_folder: The folder where the file should be moved to - :type: Folder - """ file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) if file: dst_folder = self.get_folder(folder_name=dst_folder_name) @@ -243,6 +236,9 @@ class FileSystem(SimComponent): dst_folder = self.create_folder(dst_folder_name) new_file = file.make_copy(dst_folder=dst_folder) dst_folder.add_file(new_file) + if file.real: + new_file.sim_path.parent.mkdir(exist_ok=True) + shutil.copy2(file.sim_path, new_file.sim_path) def get_folder(self, folder_name: str) -> Optional[Folder]: """ @@ -419,7 +415,7 @@ class File(FileSystemItemABC): pass def make_copy(self, dst_folder: Folder) -> File: - return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder"})) + return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"})) @property def path(self): diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index dcad59f8..832e6a13 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -921,16 +921,19 @@ class Node(SimComponent): kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) - if not kwargs.get("software_manager"): - kwargs["software_manager"] = SoftwareManager( - sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager") - ) if not kwargs.get("root"): kwargs["root"] = SIM_OUTPUT / kwargs["hostname"] if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") + if not kwargs.get("software_manager"): + kwargs["software_manager"] = SoftwareManager( + sys_log=kwargs.get("sys_log"), + session_manager=kwargs.get("session_manager"), + file_system=kwargs.get("file_system") + ) super().__init__(**kwargs) self.arp.nics = self.nics + self.session_manager.software_manager = self.software_manager def describe_state(self) -> Dict: """ @@ -1097,6 +1100,8 @@ class Node(SimComponent): if frame.ip.protocol == IPProtocol.TCP: if frame.tcp.src_port == Port.ARP: self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) + else: + self.session_manager.receive_frame(frame) elif frame.ip.protocol == IPProtocol.UDP: pass elif frame.ip.protocol == IPProtocol.ICMP: diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 6a50fe3f..cecb108d 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.database import DatabaseService def client_server_routed() -> Network: @@ -160,6 +161,39 @@ def arcd_uc2_network() -> Network: database_server.power_on() network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3]) + ddl = """ + CREATE TABLE IF NOT EXISTS user ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50) NOT NULL, + email VARCHAR(50) NOT NULL, + age INT, + city VARCHAR(50), + occupation VARCHAR(50) + );""" + + user_insert_statements = [ + "INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", + ] + database_server.software_manager.add_service(DatabaseService) + database: DatabaseService = database_server.software_manager.services["Database"] # noqa + database.start() + database._process_sql(ddl) # noqa + for insert_statement in user_insert_statements: + database._process_sql(insert_statement) # noqa + # Backup Server backup_server = Server( hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" @@ -183,4 +217,7 @@ def arcd_uc2_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) + + return network diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index f4521096..79e3630a 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -1,3 +1,4 @@ +import json import logging from pathlib import Path from typing import Optional @@ -51,6 +52,13 @@ class PacketCapture: self.logger.addFilter(_JSONFilter()) + def read(self): + frames = [] + with open(self._get_log_path(), "r") as file: + while line := file.readline(): + frames.append(json.loads(line.rstrip())) + return frames + @property def _logger_name(self) -> str: """Get PCAP the logger name.""" diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index be20a28d..aa73410f 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -32,15 +32,14 @@ class Session(SimComponent): """ protocol: IPProtocol - src_ip_address: IPv4Address - dst_ip_address: IPv4Address + with_ip_address: IPv4Address src_port: Optional[Port] dst_port: Optional[Port] connected: bool = False @classmethod def from_session_key( - cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]] + cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]] ) -> Session: """ Create a Session instance from a session key tuple. @@ -48,11 +47,10 @@ class Session(SimComponent): :param session_key: Tuple containing the session details. :return: A Session instance. """ - protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key + protocol, with_ip_address, src_port, dst_port = session_key return Session( protocol=protocol, - src_ip_address=src_ip_address, - dst_ip_address=dst_ip_address, + with_ip_address=with_ip_address, src_port=src_port, dst_port=dst_port, ) @@ -99,8 +97,8 @@ class SessionManager: @staticmethod def _get_session_key( - frame: Frame, from_source: bool = True - ) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]: + frame: Frame, inbound_frame: bool = True + ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: """ Extracts the session key from the given frame. @@ -112,38 +110,38 @@ class SessionManager: - Optional[Port]: The destination port number (if applicable). :param frame: The network frame from which to extract the session key. - :param from_source: A flag to indicate if the key should be extracted from the source or destination. :return: A tuple containing the session key. """ protocol = frame.ip.protocol - src_ip_address = frame.ip.src_ip_address - dst_ip_address = frame.ip.dst_ip_address + with_ip_address = frame.ip.src_ip_address if protocol == IPProtocol.TCP: - if from_source: + if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port else: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port + with_ip_address = frame.ip.dst_ip_address elif protocol == IPProtocol.UDP: - if from_source: + if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port else: dst_port = frame.udp.src_port src_port = frame.udp.dst_port + with_ip_address = frame.ip.dst_ip_address else: src_port = None dst_port = None - return protocol, src_ip_address, dst_ip_address, src_port, dst_port + return protocol, with_ip_address, src_port, dst_port def receive_payload_from_software_manager( - self, - payload: Any, - dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, - session_id: Optional[str] = None, - is_reattempt: bool = False, + self, + payload: Any, + dst_ip_address: Optional[IPv4Address] = None, + dst_port: Optional[Port] = None, + session_id: Optional[str] = None, + is_reattempt: bool = False, ) -> Union[Any, None]: """ Receive a payload from the SoftwareManager. @@ -154,23 +152,21 @@ class SessionManager: :param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created. """ if session_id: - dest_ip_address = self.sessions_by_uuid[session_id].dst_ip_address - dest_port = self.sessions_by_uuid[session_id].dst_port + session = self.sessions_by_uuid[session_id] + dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address + dst_port = self.sessions_by_uuid[session_id].dst_port - dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dest_ip_address) + dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address) if dst_mac_address: - outbound_nic = self.arp_cache.get_arp_cache_nic(dest_ip_address) + outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address) else: if not is_reattempt: - self.arp_cache.send_arp_request(dest_ip_address) + self.arp_cache.send_arp_request(dst_ip_address) return self.receive_payload_from_software_manager( - payload=payload, - dest_ip_address=dest_ip_address, - dest_port=dest_port, - session_id=session_id, - is_reattempt=True, - ) + payload=payload, dst_ip_address=dst_ip_address, dst_port=dst_port, session_id=session_id, + is_reattempt=True + ) else: return @@ -178,17 +174,17 @@ class SessionManager: ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address), ip=IPPacket( src_ip_address=outbound_nic.ip_address, - dst_ip_address=dest_ip_address, + dst_ip_address=dst_ip_address, ), tcp=TCPHeader( - src_port=dest_port, - dst_port=dest_port, + src_port=dst_port, + dst_port=dst_port, ), payload=payload, ) if not session_id: - session_key = self._get_session_key(frame, from_source=True) + session_key = self._get_session_key(frame, inbound_frame=False) session = self.sessions_by_key.get(session_key) if not session: # Create new session @@ -198,33 +194,25 @@ class SessionManager: outbound_nic.send_frame(frame) - def send_payload_to_software_manager(self, payload: Any, session_id: int): + def receive_frame(self, frame: Frame): """ - Send a payload to the software manager. - - :param payload: The payload to be sent. - :param session_id: The Session ID the payload originates from. - """ - self.software_manager.receive_payload_from_session_manger() - - def receive_payload_from_nic(self, frame: Frame): - """ - Receive a Frame from the NIC. + Receive a Frame. Extract the session key using the _get_session_key method, and forward the payload to the appropriate session. If the session does not exist, a new one is created. :param frame: The frame being received. """ - session_key = self._get_session_key(frame) - session = self.sessions_by_key.get(session_key) + session_key = self._get_session_key(frame, inbound_frame=True) + session: Session = self.sessions_by_key.get(session_key) if not session: # Create new session session = Session.from_session_key(session_key) self.sessions_by_key[session_key] = session self.sessions_by_uuid[session.uuid] = session - self.software_manager.receive_payload_from_session_manger(payload=frame, session=session) - # TODO: Implement the frame deconstruction and send to SoftwareManager. + self.software_manager.receive_payload_from_session_manger( + payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid + ) def show(self, markdown: bool = False): """ diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 28e37963..d46cb21c 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable +from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application @@ -23,7 +24,7 @@ ServiceClass = TypeVar("ServiceClass", bound=Service) class SoftwareManager: """A class that manages all running Services and Applications on a Node and facilitates their communication.""" - def __init__(self, session_manager: "SessionManager", sys_log: "SysLog"): + def __init__(self, session_manager: "SessionManager", sys_log: SysLog, file_system: FileSystem): """ Initialize a new instance of SoftwareManager. @@ -34,6 +35,7 @@ class SoftwareManager: self.applications: Dict[str, Application] = {} self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log + self.file_system: FileSystem = file_system def add_service(self, service_class: Type[ServiceClass]): """ @@ -41,7 +43,7 @@ class SoftwareManager: :param: service_class: The class of the service to add """ - service = service_class(software_manager=self, sys_log=self.sys_log) + service = service_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system) service.software_manager = self self.services[service.name] = service @@ -86,7 +88,7 @@ class SoftwareManager: payload: Any, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = None, - session_id: Optional[int] = None, + session_id: Optional[str] = None, ): """ Send a payload to the SessionManager. @@ -97,21 +99,21 @@ class SoftwareManager: :param session_id: The Session ID the payload is to originate from. Optional. """ self.session_manager.receive_payload_from_software_manager( - payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id - ) + payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id + ) - def receive_payload_from_session_manger(self, payload: Any, session: Session): + def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): """ Receive a payload from the SessionManager and forward it to the corresponding service or application. :param payload: The payload being received. :param session: The transport session the payload originates from. """ - # receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) - # if receiver: - # receiver.receive_payload(None, payload) - # else: - # raise ValueError(f"No service or application found for port {port} and protocol {protocol}") + receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) + if receiver: + receiver.receive(payload=payload, session_id=session_id) + else: + self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}") pass def show(self, markdown: bool = False): diff --git a/src/primaite/simulator/system/services/database.py b/src/primaite/simulator/system/services/database.py index 67ee5cc3..7666597c 100644 --- a/src/primaite/simulator/system/services/database.py +++ b/src/primaite/simulator/system/services/database.py @@ -1,15 +1,61 @@ -from typing import Dict +import sqlite3 +from ipaddress import IPv4Address +from sqlite3 import OperationalError +from typing import Dict, Optional, Any, List, Union -from primaite.simulator.file_system.file_type import FileType -from primaite.simulator.network.hardware.base import Node +from prettytable import PrettyTable, MARKDOWN + +from primaite.simulator.file_system.file_system import File +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service class DatabaseService(Service): """A generic SQL Server Service.""" + backup_server: Optional[IPv4Address] = None + "The IP Address of the server the " def __init__(self, **kwargs): + kwargs["name"] = "Database" + kwargs["port"] = Port.POSTGRES_SERVER + kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + self._db_file: File + self._create_db_file() + self._conn = sqlite3.connect(self._db_file.sim_path) + self._cursor = self._conn.cursor() + + def tables(self) -> List[str]: + sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';" + results = self._process_sql(sql) + return [row[0] for row in results["data"]] + + def show(self, markdown: bool = False): + """Prints a Table names in the Database.""" + table = PrettyTable(["Table"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.file_system.sys_log.hostname} Database" + for row in self.tables(): + table.add_row([row]) + print(table) + + def _create_db_file(self): + self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True) + self.folder = self._db_file.folder + + def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]: + try: + self._cursor.execute(query) + self._conn.commit() + except OperationalError: + # Handle the case where the table does not exist. + return {"status_code": 404, "data": []} + + return {"status_code": 200, "data": self._cursor.fetchall()} def describe_state(self) -> Dict: """ @@ -22,53 +68,12 @@ class DatabaseService(Service): """ return super().describe_state() - @classmethod - def install(cls, node: Node): + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + result = self._process_sql(payload) + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager(payload=result, session_id=session_id) + return result["status_code"] - def uninstall(self) -> None: - """ - Undo installation procedure. - - This method deletes files created when installing the database, and the database folder if it is empty. - """ - super().uninstall() - node: Node = self.parent - node.file_system.delete_file(self.primary_store) - node.file_system.delete_file(self.transaction_log) - if self.secondary_store: - node.file_system.delete_file(self.secondary_store) - if len(self.folder.files) == 0: - node.file_system.delete_folder(self.folder) - - def install(self) -> None: - """Perform first time install on a node, creating necessary files.""" - super().install() - assert isinstance(self.parent, Node), "Database install can only happen after the db service is added to a node" - self._setup_files() - - def _setup_files( - self, - folder_name: str = "database", - ): - """Set up files that are required by the database on the parent host. - - :param folder_name: Name of the folder which will be setup to hold the db files, defaults to "database" - :type folder_name: str, optional - """ - # note that this parent.file_system.create_folder call in the future will be authenticated by using permissions - # handler. This permission will be granted based on service account given to the database service. - self.parent: Node - self.folder = self.parent.file_system.create_folder(folder_name) - self.primary_store = self.parent.file_system.create_file( - "db_primary_store", db_size, FileType.MDF, folder=self.folder - ) - self.transaction_log = self.parent.file_system.create_file( - "db_transaction_log", "1", FileType.LDF, folder=self.folder - ) - if use_secondary_db_file: - self.secondary_store = self.parent.file_system.create_file( - "db_secondary_store", secondary_db_size, FileType.NDF, folder=self.folder - ) - else: - self.secondary_store = None + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + pass diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 7f206311..70c1bbf2 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,8 +1,9 @@ from abc import abstractmethod from enum import Enum -from typing import Any, Dict +from typing import Any, Dict, Optional from primaite.simulator.core import Action, ActionManager, SimComponent +from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog @@ -79,6 +80,10 @@ class Software(SimComponent): "An instance of Software Manager that is used by the parent node." sys_log: SysLog = None "An instance of SysLog that is used by the parent node." + file_system: FileSystem + "The FileSystem of the Node the Software is installed on." + folder: Optional[Folder] = None + "The folder on the file system the Software uses." def _init_action_manager(self) -> ActionManager: am = super()._init_action_manager() @@ -216,7 +221,6 @@ class IOSoftware(Software): :param kwargs: Additional keyword arguments specific to the implementation. :return: True if the payload was successfully sent, False otherwise. """ - pass def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ diff --git a/tests/conftest.py b/tests/conftest.py index f1c05187..5570e21f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,17 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) +# PrimAITE v3 stuff +from primaite.simulator.file_system.file_system import FileSystem +from primaite.simulator.network.hardware.base import Node + +@pytest.fixture(scope="function") +def file_system() -> FileSystem: + return Node(hostname="fs_node").file_system + + +#PrimAITE v2 stuff class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index ef2a58e4..0d66137b 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -1,52 +1,46 @@ -from primaite.simulator.network.hardware.base import Node -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.database import DatabaseService -from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState +from ipaddress import IPv4Address + +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.networks import arcd_uc2_network +from primaite.simulator.network.transmission.data_link_layer import Frame, EthernetHeader +from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence +from primaite.simulator.network.transmission.transport_layer import TCPHeader, Port -def test_installing_database(): - db = DatabaseService( - name="SQL-database", - health_state_actual=SoftwareHealthState.GOOD, - health_state_visible=SoftwareHealthState.GOOD, - criticality=SoftwareCriticality.MEDIUM, - port=Port.SQL_SERVER, - operating_state=ServiceOperatingState.RUNNING, +def test_database_query_across_the_network(): + """Tests DB query across the network returns HTTP status 200 and date.""" + network = arcd_uc2_network() + + client_1: Computer = network.get_node_by_hostname("client_1") + + client_1.arp.send_arp_request(IPv4Address("192.168.1.14")) + + dst_mac_address = client_1.arp.get_arp_cache_mac_address(IPv4Address("192.168.1.14")) + + outbound_nic = client_1.arp.get_arp_cache_nic(IPv4Address("192.168.1.14")) + client_1.ping("192.168.1.14") + + + frame = Frame( + ethernet=EthernetHeader( + src_mac_addr=client_1.ethernet_port[1].mac_address, + dst_mac_addr=dst_mac_address + ), + ip=IPPacket( + src_ip_address=client_1.ethernet_port[1].ip_address, + dst_ip_address=IPv4Address("192.168.1.14"), + precedence=Precedence.FLASH + ), + tcp=TCPHeader( + src_port=Port.POSTGRES_SERVER, + dst_port=Port.POSTGRES_SERVER + ), + payload="SELECT * FROM user;" ) - node = Node(hostname="db-server") + outbound_nic.send_frame(frame) - node.install_service(db) + client_1_last_payload = outbound_nic.pcap.read()[-1]["payload"] - assert db in node - - file_exists = False - for folder in node.file_system.folders.values(): - for file in folder.files.values(): - if file.name == "db_primary_store": - file_exists = True - break - if file_exists: - break - assert file_exists - - -def test_uninstalling_database(): - db = DatabaseService( - name="SQL-database", - health_state_actual=SoftwareHealthState.GOOD, - health_state_visible=SoftwareHealthState.GOOD, - criticality=SoftwareCriticality.MEDIUM, - port=Port.SQL_SERVER, - operating_state=ServiceOperatingState.RUNNING, - ) - - node = Node(hostname="db-server") - - node.install_service(db) - - node.uninstall_service(db) - - assert db not in node - assert node.file_system.get_folder("database") is None + assert client_1_last_payload["status_code"] == 200 + assert client_1_last_payload["data"] \ No newline at end of file diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 136961e2..d1d78003 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,19 +1,13 @@ import pytest -from primaite.simulator.file_system.file_system import File, FileSystem, Folder +from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.file_system.file_type import FileType -from primaite.simulator.network.hardware.base import Node - - -@pytest.fixture(scope="function") -def file_system() -> FileSystem: - return Node(hostname="fs_node").file_system def test_create_folder_and_file(file_system): """Test creating a folder and a file.""" assert len(file_system.folders) == 1 - test_folder = file_system.create_folder(folder_name="test_folder") + file_system.create_folder(folder_name="test_folder") assert len(file_system.folders) is 2 file_system.create_file(file_name="test_file.txt", folder_name="test_folder") @@ -115,7 +109,7 @@ def test_copy_file(file_system): file_system.create_folder(folder_name="src_folder") file_system.create_folder(folder_name="dst_folder") - file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder") + file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True) original_uuid = file.uuid assert len(file_system.get_folder("src_folder").files) == 1 @@ -128,6 +122,19 @@ def test_copy_file(file_system): assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid +def test_folder_quarantine_state(file_system): + """Tests the changing of folder quarantine status.""" + folder = file_system.get_folder("root") + + assert folder.quarantine_status() is False + + folder.quarantine() + assert folder.quarantine_status() is True + + folder.unquarantine() + assert folder.quarantine_status() is False + + @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py deleted file mode 100644 index 981550f3..00000000 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py +++ /dev/null @@ -1,23 +0,0 @@ -from primaite.simulator.file_system.file_system import File -from primaite.simulator.file_system.file_type import FileType - - -def test_file_type(): - """Tests tha the File type is set correctly.""" - file = File(name="test", file_type=FileType.DOC) - assert file.file_type is FileType.DOC - - -def test_get_size(): - """Tests that the file size is being returned properly.""" - file = File(name="test", size=1.5) - assert file.size == 1.5 - - -def test_serialisation(): - """Test to check that the object serialisation works correctly.""" - file = File(name="test", size=1.5, file_type=FileType.DOC) - serialised_file = file.model_dump_json() - deserialised_file = File.model_validate_json(serialised_file) - - assert file.model_dump_json() == deserialised_file.model_dump_json() diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py deleted file mode 100644 index 72684146..00000000 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py +++ /dev/null @@ -1,75 +0,0 @@ -from primaite.simulator.file_system.file_system import File -from primaite.simulator.file_system.file_system_folder import Folder -from primaite.simulator.file_system.file_type import FileType - - -def test_adding_removing_file(): - """Test the adding and removing of a file from a folder.""" - folder = Folder(name="test") - - file = File(name="test_file", size=10, file_type=FileType.DOC) - - folder.add_file(file) - assert folder.size == 10 - assert len(folder.files) is 1 - - folder.remove_file(file) - assert folder.size == 0 - assert len(folder.files) is 0 - - -def test_remove_non_existent_file(): - """Test the removing of a file that does not exist.""" - folder = Folder(name="test") - - file = File(name="test_file", size=10, file_type=FileType.DOC) - not_added_file = File(name="fake_file", size=10, file_type=FileType.DOC) - - folder.add_file(file) - assert folder.size == 10 - assert len(folder.files) is 1 - - folder.remove_file(not_added_file) - assert folder.size == 10 - assert len(folder.files) is 1 - - -def test_get_file_by_id(): - """Test to make sure that the correct file is returned.""" - folder = Folder(name="test") - - file = File(name="test_file", size=10, file_type=FileType.DOC) - file2 = File(name="test_file_2", size=10, file_type=FileType.DOC) - - folder.add_file(file) - folder.add_file(file2) - assert folder.size == 20 - assert len(folder.files) is 2 - - assert folder.get_file_by_id(file_id=file.uuid) is file - - -def test_folder_quarantine_state(): - """Tests the changing of folder quarantine status.""" - folder = Folder(name="test") - - assert folder.quarantine_status() is False - - folder.quarantine() - assert folder.quarantine_status() is True - - folder.unquarantine() - assert folder.quarantine_status() is False - - -def test_serialisation(): - """Test to check that the object serialisation works correctly.""" - folder = Folder(name="test") - file = File(name="test_file", size=10, file_type=FileType.DOC) - folder.add_file(file) - - serialised_folder = folder.model_dump_json() - - deserialised_folder = Folder.model_validate_json(serialised_folder) - - assert folder.model_dump_json() == deserialised_folder.model_dump_json() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py index f5b37175..9496a50e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py @@ -29,4 +29,4 @@ def test_creation(): assert False, f"Test was not supposed to throw exception: {e}" # there should be a session after the service is started - assert len(client_1.session_manager.sessions_by_uuid) == 1 + assert len(client_1.session_manager.sessions_by_uuid) == 1 \ No newline at end of file diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index ebc5536f..acc05d17 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -1,15 +1,59 @@ -from primaite.simulator.network.transmission.transport_layer import Port +import json + +import pytest + +from primaite.simulator.network.hardware.base import Node from primaite.simulator.system.services.database import DatabaseService -from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState + +DDL = """ +CREATE TABLE IF NOT EXISTS user ( +id INTEGER PRIMARY KEY AUTOINCREMENT, +name VARCHAR(50) NOT NULL, +email VARCHAR(50) NOT NULL, +age INT, +city VARCHAR(50), +occupation VARCHAR(50) +);""" + +USER_INSERT_STATEMENTS = [ + "INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", +] -def test_creation(): - db = DatabaseService( - name="SQL-database", - health_state_actual=SoftwareHealthState.GOOD, - health_state_visible=SoftwareHealthState.GOOD, - criticality=SoftwareCriticality.MEDIUM, - port=Port.SQL_SERVER, - operating_state=ServiceOperatingState.RUNNING, - ) +@pytest.fixture(scope="function") +def database_server() -> Node: + node = Node(hostname="db_node") + node.software_manager.add_service(DatabaseService) + node.software_manager.services["Database"].start() + return node + + +@pytest.fixture(scope="function") +def database(database_server) -> DatabaseService: + database: DatabaseService = database_server.software_manager.services["Database"] # noqa + database.receive(DDL, None) + for script in USER_INSERT_STATEMENTS: + database.receive(script, None) + return database + + +def test_creation(database_server): + database_server.software_manager.show() + + +def test_db_population(database): + database.show() + assert database.tables() == ["user"] \ No newline at end of file