diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index d5e81e1b..b2037729 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -3,7 +3,6 @@ from __future__ import annotations import math import os.path import shutil -from abc import abstractmethod from pathlib import Path from typing import Dict, Optional @@ -17,7 +16,7 @@ from primaite.simulator.system.core.sys_log import SysLog _LOGGER = getLogger(__name__) -def convert_size(size_bytes): +def convert_size(size_bytes: int) -> str: """ Convert a file size from bytes to a string with a more human-readable format. @@ -44,7 +43,11 @@ def convert_size(size_bytes): class FileSystemItemABC(SimComponent): - """Abstract base class for file system items used in the file system simulation.""" + """ + Abstract base class for file system items used in the file system simulation. + + :ivar name: The name of the FileSystemItemABC. + """ name: str "The name of the FileSystemItemABC." @@ -64,7 +67,15 @@ class FileSystemItemABC(SimComponent): return state @property - def size_str(self): + def size_str(self) -> str: + """ + Get the file size in a human-readable string format. + + This property makes use of the :func:`convert_size` function to convert the `self.size` attribute to a string + that is easier to read and understand. + + :return: The human-readable string representation of the file size. + """ return convert_size(self.size) @@ -84,11 +95,21 @@ class FileSystem(SimComponent): self.create_folder("root") @property - def size(self): + def size(self) -> int: + """ + Calculate and return the total size of all folders in the file system. + + :return: The sum of the sizes of all folders in the file system. + """ return sum(folder.size for folder in self.folders.values()) def show(self, markdown: bool = False, full: bool = False): - """Prints a of the FileSystem""" + """ + Prints a table of the FileSystem, displaying either just folders or full files. + + :param markdown: Flag indicating if output should be in markdown format. + :param full: Flag indicating if to show full files. + """ headers = ["Folder", "Size"] if full: headers[0] = "File Path" @@ -171,7 +192,6 @@ class FileSystem(SimComponent): :param folder_name: The folder to add the file to. :param real: "Indicates whether the File is actually a real file in the Node sim fs output." """ - if folder_name: # check if file with name already exists folder = self._folders_by_name.get(folder_name) @@ -196,12 +216,25 @@ class FileSystem(SimComponent): return file def get_file(self, folder_name: str, file_name: str) -> Optional[File]: + """ + Retrieve a file by its name from a specific folder. + + :param folder_name: The name of the folder where the file resides. + :param file_name: The name of the file to be retrieved, including its extension. + :return: An instance of File if it exists, otherwise `None`. + """ folder = self.get_folder(folder_name) if folder: return folder.get_file(file_name) self.fs.sys_log.info(f"file not found /{folder_name}/{file_name}") def delete_file(self, folder_name: str, file_name: str): + """ + Delete a file by its name from a specific folder. + + :param folder_name: The name of the folder containing the file. + :param file_name: The name of the file to be deleted, including its extension. + """ folder = self.get_folder(folder_name) if folder: file = folder.get_file(file_name) @@ -209,7 +242,14 @@ class FileSystem(SimComponent): folder.remove_file(file) self.sys_log.info(f"Deleted file /{file.path}") - def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name): + def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): + """ + Move a file from one folder to another. + + :param src_folder_name: The name of the source folder containing the file. + :param src_file_name: The name of the file to be moved. + :param dst_folder_name: The name of the destination folder. + """ file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) if file: src_folder = file.folder @@ -227,8 +267,14 @@ class FileSystem(SimComponent): file.sim_path.parent.mkdir(exist_ok=True) shutil.move(old_sim_path, file.sim_path) - def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name): + def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): + """ + Copy a file from one folder to another. + :param src_folder_name: The name of the source folder containing the file. + :param src_file_name: The name of the file to be copied. + :param dst_folder_name: The name of the destination folder. + """ file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) if file: dst_folder = self.get_folder(folder_name=dst_folder_name) @@ -283,7 +329,11 @@ class Folder(FileSystemItemABC): return state def show(self, markdown: bool = False): - """Prints a of the Folder""" + """ + Display the contents of the Folder in tabular format. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ table = PrettyTable(["File", "Size"]) if markdown: table.set_style(MARKDOWN) @@ -294,7 +344,13 @@ class Folder(FileSystemItemABC): print(table.get_string(sortby="File")) @property - def size(self): + def size(self) -> int: + """ + Calculate and return the total size of all files in the folder. + + :return: The total size of all files in the folder. If no files exist or all have `None` + size, returns 0. + """ return sum(file.size for file in self.files.values() if file.size is not None) def get_file(self, file_name: str) -> Optional[File]: @@ -313,14 +369,19 @@ class Folder(FileSystemItemABC): """ Get a file by its uuid. - :param file_uuid: The file uuid. :return: The matching File. """ return self.files.get(file_uuid) def add_file(self, file: File): - """Adds a file to the folder list.""" + """ + Adds a file to the folder. + + :param File file: The File object to be added to the folder. + :raises Exception: If the provided `file` parameter is None or not an instance of the + `File` class. + """ if file is None or not isinstance(file, File): raise Exception(f"Invalid file: {file}") @@ -340,7 +401,6 @@ class Folder(FileSystemItemABC): The method can take a File object or a file id. :param file: The file to remove - :type: Optional[File] """ if file is None or not isinstance(file, File): raise Exception(f"Invalid file: {file}") @@ -369,7 +429,15 @@ class Folder(FileSystemItemABC): class File(FileSystemItemABC): - """Class that represents a file in the simulation.""" + """ + Class representing a file in the simulation. + + :ivar Folder folder: The folder in which the file resides. + :ivar FileType file_type: The type of the file. + :ivar Optional[int] sim_size: The simulated file size. + :ivar bool real: Indicates if the file is actually a real file in the Node sim fs output. + :ivar Optional[Path] sim_path: The path if the file is real. + """ folder: Folder "The Folder the File is in." @@ -415,16 +483,30 @@ class File(FileSystemItemABC): pass def make_copy(self, dst_folder: Folder) -> File: + """ + Create a copy of the current File object in the given destination folder. + + :param Folder dst_folder: The destination folder for the copied file. + :return: A new File object that is a copy of the current file. + """ return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"})) @property - def path(self): - """The path of the file in the FileSystem.""" + def path(self) -> str: + """ + Get the path of the file in the file system. + + :return: The full path of the file. + """ return f"{self.folder.name}/{self.name}" @property def size(self) -> int: - """The file size in Bytes.""" + """ + Get the size of the file in bytes. + + :return: The size of the file in bytes. + """ if self.real: return os.path.getsize(self.sim_path) return self.sim_size diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 832e6a13..fa1058cf 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -929,7 +929,7 @@ class Node(SimComponent): kwargs["software_manager"] = SoftwareManager( sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager"), - file_system=kwargs.get("file_system") + file_system=kwargs.get("file_system"), ) super().__init__(**kwargs) self.arp.nics = self.nics diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index cecb108d..c030d907 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -172,20 +172,20 @@ def arcd_uc2_network() -> Network: );""" user_insert_statements = [ - "INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", - "INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", + "INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa ] database_server.software_manager.add_service(DatabaseService) database: DatabaseService = database_server.software_manager.services["Database"] # noqa @@ -219,5 +219,4 @@ def arcd_uc2_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) - return network diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index 79e3630a..b1e35a77 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -1,7 +1,7 @@ import json import logging from pathlib import Path -from typing import Optional +from typing import Any, Dict, List, Optional from primaite.simulator import SIM_OUTPUT @@ -52,7 +52,12 @@ class PacketCapture: self.logger.addFilter(_JSONFilter()) - def read(self): + def read(self) -> List[Dict[Any]]: + """ + Read packet capture logs and return them as a list of dictionaries. + + :return: List of frames captured, represented as dictionaries. + """ frames = [] with open(self._get_log_path(), "r") as file: while line := file.readline(): diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index aa73410f..71b7dcec 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -38,9 +38,7 @@ class Session(SimComponent): connected: bool = False @classmethod - def from_session_key( - cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]] - ) -> Session: + def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session: """ Create a Session instance from a session key tuple. @@ -97,7 +95,7 @@ class SessionManager: @staticmethod def _get_session_key( - frame: Frame, inbound_frame: bool = True + frame: Frame, inbound_frame: bool = True ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: """ Extracts the session key from the given frame. @@ -136,12 +134,12 @@ class SessionManager: return protocol, with_ip_address, src_port, dst_port def receive_payload_from_software_manager( - self, - payload: Any, - dst_ip_address: Optional[IPv4Address] = None, - dst_port: Optional[Port] = None, - session_id: Optional[str] = None, - is_reattempt: bool = False, + self, + payload: Any, + dst_ip_address: Optional[IPv4Address] = None, + dst_port: Optional[Port] = None, + session_id: Optional[str] = None, + is_reattempt: bool = False, ) -> Union[Any, None]: """ Receive a payload from the SoftwareManager. @@ -164,9 +162,12 @@ class SessionManager: if not is_reattempt: self.arp_cache.send_arp_request(dst_ip_address) return self.receive_payload_from_software_manager( - payload=payload, dst_ip_address=dst_ip_address, dst_port=dst_port, session_id=session_id, - is_reattempt=True - ) + payload=payload, + dst_ip_address=dst_ip_address, + dst_port=dst_port, + session_id=session_id, + is_reattempt=True, + ) else: return diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index d46cb21c..13d4524c 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -7,7 +7,6 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application -from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareType @@ -100,7 +99,7 @@ class SoftwareManager: """ self.session_manager.receive_payload_from_software_manager( payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id - ) + ) def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): """ diff --git a/src/primaite/simulator/system/services/database.py b/src/primaite/simulator/system/services/database.py index 7666597c..c02b2872 100644 --- a/src/primaite/simulator/system/services/database.py +++ b/src/primaite/simulator/system/services/database.py @@ -1,9 +1,9 @@ import sqlite3 from ipaddress import IPv4Address from sqlite3 import OperationalError -from typing import Dict, Optional, Any, List, Union +from typing import Any, Dict, List, Optional, Union -from prettytable import PrettyTable, MARKDOWN +from prettytable import MARKDOWN, PrettyTable from primaite.simulator.file_system.file_system import File from primaite.simulator.network.transmission.network_layer import IPProtocol @@ -13,7 +13,12 @@ from primaite.simulator.system.services.service import Service class DatabaseService(Service): - """A generic SQL Server Service.""" + """ + A class for simulating a generic SQL Server service. + + This class inherits from the `Service` class and provides methods to manage and query a SQLite database. + """ + backup_server: Optional[IPv4Address] = None "The IP Address of the server the " @@ -28,12 +33,21 @@ class DatabaseService(Service): self._cursor = self._conn.cursor() def tables(self) -> List[str]: + """ + Get a list of table names present in the database. + + :return: List of table names. + """ sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';" results = self._process_sql(sql) return [row[0] for row in results["data"]] def show(self, markdown: bool = False): - """Prints a Table names in the Database.""" + """ + Prints a list of table names in the database using PrettyTable. + + :param markdown: Whether to output the table in Markdown format. + """ table = PrettyTable(["Table"]) if markdown: table.set_style(MARKDOWN) @@ -44,10 +58,17 @@ class DatabaseService(Service): print(table) def _create_db_file(self): + """Creates the Simulation File and sqlite file in the file system.""" self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True) self.folder = self._db_file.folder def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]: + """ + Executes the given SQL query and returns the result. + + :param query: The SQL query to be executed. + :return: Dictionary containing status code and data fetched. + """ try: self._cursor.execute(query) self._conn.commit() @@ -69,11 +90,15 @@ class DatabaseService(Service): return super().describe_state() def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Processes the incoming SQL payload and sends the result back. + + :param payload: The SQL query to be executed. + :param session_id: The session identifier. + :return: The status code of the SQL execution. + """ result = self._process_sql(payload) software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager(payload=result, session_id=session_id) - return result["status_code"] - - def send(self, payload: Any, session_id: str, **kwargs) -> bool: - pass + return result["status_code"] == 200 diff --git a/tests/conftest.py b/tests/conftest.py index 5570e21f..9c216a8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ def file_system() -> FileSystem: return Node(hostname="fs_node").file_system -#PrimAITE v2 stuff +# PrimAITE v2 stuff class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 0d66137b..7ad11222 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -2,9 +2,9 @@ from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.data_link_layer import Frame, EthernetHeader +from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence -from primaite.simulator.network.transmission.transport_layer import TCPHeader, Port +from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader def test_database_query_across_the_network(): @@ -20,22 +20,15 @@ def test_database_query_across_the_network(): outbound_nic = client_1.arp.get_arp_cache_nic(IPv4Address("192.168.1.14")) client_1.ping("192.168.1.14") - frame = Frame( - ethernet=EthernetHeader( - src_mac_addr=client_1.ethernet_port[1].mac_address, - dst_mac_addr=dst_mac_address - ), + ethernet=EthernetHeader(src_mac_addr=client_1.ethernet_port[1].mac_address, dst_mac_addr=dst_mac_address), ip=IPPacket( src_ip_address=client_1.ethernet_port[1].ip_address, dst_ip_address=IPv4Address("192.168.1.14"), - precedence=Precedence.FLASH + precedence=Precedence.FLASH, ), - tcp=TCPHeader( - src_port=Port.POSTGRES_SERVER, - dst_port=Port.POSTGRES_SERVER - ), - payload="SELECT * FROM user;" + tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + payload="SELECT * FROM user;", ) outbound_nic.send_frame(frame) @@ -43,4 +36,4 @@ def test_database_query_across_the_network(): client_1_last_payload = outbound_nic.pcap.read()[-1]["payload"] assert client_1_last_payload["status_code"] == 200 - assert client_1_last_payload["data"] \ No newline at end of file + assert client_1_last_payload["data"] diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py index 9496a50e..f5b37175 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulator_service.py @@ -29,4 +29,4 @@ def test_creation(): assert False, f"Test was not supposed to throw exception: {e}" # there should be a session after the service is started - assert len(client_1.session_manager.sessions_by_uuid) == 1 \ No newline at end of file + assert len(client_1.session_manager.sessions_by_uuid) == 1 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index acc05d17..f3751f27 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -56,4 +56,4 @@ def test_creation(database_server): def test_db_population(database): database.show() - assert database.tables() == ["user"] \ No newline at end of file + assert database.tables() == ["user"]