#1816 - Added the final pieces of the puzzle to get data up from NIC → session manager → software manager → service.

- Implemented a basic sim DB that matches UC2 data manipulation DB in IY.
- Added a test that confirms DB queries can be sent over the network.
This commit is contained in:
Chris McCarthy
2023-09-06 22:26:23 +01:00
parent 6b41bec32a
commit 2f744af34e
11 changed files with 180 additions and 76 deletions

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import math
import os.path
import shutil
from abc import abstractmethod
from pathlib import Path
from typing import Dict, Optional
@@ -17,7 +16,7 @@ from primaite.simulator.system.core.sys_log import SysLog
_LOGGER = getLogger(__name__)
def convert_size(size_bytes):
def convert_size(size_bytes: int) -> str:
"""
Convert a file size from bytes to a string with a more human-readable format.
@@ -44,7 +43,11 @@ def convert_size(size_bytes):
class FileSystemItemABC(SimComponent):
"""Abstract base class for file system items used in the file system simulation."""
"""
Abstract base class for file system items used in the file system simulation.
:ivar name: The name of the FileSystemItemABC.
"""
name: str
"The name of the FileSystemItemABC."
@@ -64,7 +67,15 @@ class FileSystemItemABC(SimComponent):
return state
@property
def size_str(self):
def size_str(self) -> str:
"""
Get the file size in a human-readable string format.
This property makes use of the :func:`convert_size` function to convert the `self.size` attribute to a string
that is easier to read and understand.
:return: The human-readable string representation of the file size.
"""
return convert_size(self.size)
@@ -84,11 +95,21 @@ class FileSystem(SimComponent):
self.create_folder("root")
@property
def size(self):
def size(self) -> int:
"""
Calculate and return the total size of all folders in the file system.
:return: The sum of the sizes of all folders in the file system.
"""
return sum(folder.size for folder in self.folders.values())
def show(self, markdown: bool = False, full: bool = False):
"""Prints a of the FileSystem"""
"""
Prints a table of the FileSystem, displaying either just folders or full files.
:param markdown: Flag indicating if output should be in markdown format.
:param full: Flag indicating if to show full files.
"""
headers = ["Folder", "Size"]
if full:
headers[0] = "File Path"
@@ -171,7 +192,6 @@ class FileSystem(SimComponent):
:param folder_name: The folder to add the file to.
:param real: "Indicates whether the File is actually a real file in the Node sim fs output."
"""
if folder_name:
# check if file with name already exists
folder = self._folders_by_name.get(folder_name)
@@ -196,12 +216,25 @@ class FileSystem(SimComponent):
return file
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
"""
Retrieve a file by its name from a specific folder.
:param folder_name: The name of the folder where the file resides.
:param file_name: The name of the file to be retrieved, including its extension.
:return: An instance of File if it exists, otherwise `None`.
"""
folder = self.get_folder(folder_name)
if folder:
return folder.get_file(file_name)
self.fs.sys_log.info(f"file not found /{folder_name}/{file_name}")
def delete_file(self, folder_name: str, file_name: str):
"""
Delete a file by its name from a specific folder.
:param folder_name: The name of the folder containing the file.
:param file_name: The name of the file to be deleted, including its extension.
"""
folder = self.get_folder(folder_name)
if folder:
file = folder.get_file(file_name)
@@ -209,7 +242,14 @@ class FileSystem(SimComponent):
folder.remove_file(file)
self.sys_log.info(f"Deleted file /{file.path}")
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name):
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
"""
Move a file from one folder to another.
:param src_folder_name: The name of the source folder containing the file.
:param src_file_name: The name of the file to be moved.
:param dst_folder_name: The name of the destination folder.
"""
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
src_folder = file.folder
@@ -227,8 +267,14 @@ class FileSystem(SimComponent):
file.sim_path.parent.mkdir(exist_ok=True)
shutil.move(old_sim_path, file.sim_path)
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name):
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
"""
Copy a file from one folder to another.
:param src_folder_name: The name of the source folder containing the file.
:param src_file_name: The name of the file to be copied.
:param dst_folder_name: The name of the destination folder.
"""
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
dst_folder = self.get_folder(folder_name=dst_folder_name)
@@ -283,7 +329,11 @@ class Folder(FileSystemItemABC):
return state
def show(self, markdown: bool = False):
"""Prints a of the Folder"""
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
table = PrettyTable(["File", "Size"])
if markdown:
table.set_style(MARKDOWN)
@@ -294,7 +344,13 @@ class Folder(FileSystemItemABC):
print(table.get_string(sortby="File"))
@property
def size(self):
def size(self) -> int:
"""
Calculate and return the total size of all files in the folder.
:return: The total size of all files in the folder. If no files exist or all have `None`
size, returns 0.
"""
return sum(file.size for file in self.files.values() if file.size is not None)
def get_file(self, file_name: str) -> Optional[File]:
@@ -313,14 +369,19 @@ class Folder(FileSystemItemABC):
"""
Get a file by its uuid.
:param file_uuid: The file uuid.
:return: The matching File.
"""
return self.files.get(file_uuid)
def add_file(self, file: File):
"""Adds a file to the folder list."""
"""
Adds a file to the folder.
:param File file: The File object to be added to the folder.
:raises Exception: If the provided `file` parameter is None or not an instance of the
`File` class.
"""
if file is None or not isinstance(file, File):
raise Exception(f"Invalid file: {file}")
@@ -340,7 +401,6 @@ class Folder(FileSystemItemABC):
The method can take a File object or a file id.
:param file: The file to remove
:type: Optional[File]
"""
if file is None or not isinstance(file, File):
raise Exception(f"Invalid file: {file}")
@@ -369,7 +429,15 @@ class Folder(FileSystemItemABC):
class File(FileSystemItemABC):
"""Class that represents a file in the simulation."""
"""
Class representing a file in the simulation.
:ivar Folder folder: The folder in which the file resides.
:ivar FileType file_type: The type of the file.
:ivar Optional[int] sim_size: The simulated file size.
:ivar bool real: Indicates if the file is actually a real file in the Node sim fs output.
:ivar Optional[Path] sim_path: The path if the file is real.
"""
folder: Folder
"The Folder the File is in."
@@ -415,16 +483,30 @@ class File(FileSystemItemABC):
pass
def make_copy(self, dst_folder: Folder) -> File:
"""
Create a copy of the current File object in the given destination folder.
:param Folder dst_folder: The destination folder for the copied file.
:return: A new File object that is a copy of the current file.
"""
return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"}))
@property
def path(self):
"""The path of the file in the FileSystem."""
def path(self) -> str:
"""
Get the path of the file in the file system.
:return: The full path of the file.
"""
return f"{self.folder.name}/{self.name}"
@property
def size(self) -> int:
"""The file size in Bytes."""
"""
Get the size of the file in bytes.
:return: The size of the file in bytes.
"""
if self.real:
return os.path.getsize(self.sim_path)
return self.sim_size

View File

@@ -929,7 +929,7 @@ class Node(SimComponent):
kwargs["software_manager"] = SoftwareManager(
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system")
file_system=kwargs.get("file_system"),
)
super().__init__(**kwargs)
self.arp.nics = self.nics

View File

@@ -172,20 +172,20 @@ def arcd_uc2_network() -> Network:
);"""
user_insert_statements = [
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');",
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa
]
database_server.software_manager.add_service(DatabaseService)
database: DatabaseService = database_server.software_manager.services["Database"] # noqa
@@ -219,5 +219,4 @@ def arcd_uc2_network() -> Network:
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
return network

View File

@@ -1,7 +1,7 @@
import json
import logging
from pathlib import Path
from typing import Optional
from typing import Any, Dict, List, Optional
from primaite.simulator import SIM_OUTPUT
@@ -52,7 +52,12 @@ class PacketCapture:
self.logger.addFilter(_JSONFilter())
def read(self):
def read(self) -> List[Dict[Any]]:
"""
Read packet capture logs and return them as a list of dictionaries.
:return: List of frames captured, represented as dictionaries.
"""
frames = []
with open(self._get_log_path(), "r") as file:
while line := file.readline():

View File

@@ -38,9 +38,7 @@ class Session(SimComponent):
connected: bool = False
@classmethod
def from_session_key(
cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]
) -> Session:
def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session:
"""
Create a Session instance from a session key tuple.
@@ -97,7 +95,7 @@ class SessionManager:
@staticmethod
def _get_session_key(
frame: Frame, inbound_frame: bool = True
frame: Frame, inbound_frame: bool = True
) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]:
"""
Extracts the session key from the given frame.
@@ -136,12 +134,12 @@ class SessionManager:
return protocol, with_ip_address, src_port, dst_port
def receive_payload_from_software_manager(
self,
payload: Any,
dst_ip_address: Optional[IPv4Address] = None,
dst_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
self,
payload: Any,
dst_ip_address: Optional[IPv4Address] = None,
dst_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Union[Any, None]:
"""
Receive a payload from the SoftwareManager.
@@ -164,9 +162,12 @@ class SessionManager:
if not is_reattempt:
self.arp_cache.send_arp_request(dst_ip_address)
return self.receive_payload_from_software_manager(
payload=payload, dst_ip_address=dst_ip_address, dst_port=dst_port, session_id=session_id,
is_reattempt=True
)
payload=payload,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
session_id=session_id,
is_reattempt=True,
)
else:
return

View File

@@ -7,7 +7,6 @@ from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import SoftwareType
@@ -100,7 +99,7 @@ class SoftwareManager:
"""
self.session_manager.receive_payload_from_software_manager(
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
)
)
def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
"""

View File

@@ -1,9 +1,9 @@
import sqlite3
from ipaddress import IPv4Address
from sqlite3 import OperationalError
from typing import Dict, Optional, Any, List, Union
from typing import Any, Dict, List, Optional, Union
from prettytable import PrettyTable, MARKDOWN
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -13,7 +13,12 @@ from primaite.simulator.system.services.service import Service
class DatabaseService(Service):
"""A generic SQL Server Service."""
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
"""
backup_server: Optional[IPv4Address] = None
"The IP Address of the server the "
@@ -28,12 +33,21 @@ class DatabaseService(Service):
self._cursor = self._conn.cursor()
def tables(self) -> List[str]:
"""
Get a list of table names present in the database.
:return: List of table names.
"""
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
results = self._process_sql(sql)
return [row[0] for row in results["data"]]
def show(self, markdown: bool = False):
"""Prints a Table names in the Database."""
"""
Prints a list of table names in the database using PrettyTable.
:param markdown: Whether to output the table in Markdown format.
"""
table = PrettyTable(["Table"])
if markdown:
table.set_style(MARKDOWN)
@@ -44,10 +58,17 @@ class DatabaseService(Service):
print(table)
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
self.folder = self._db_file.folder
def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
try:
self._cursor.execute(query)
self._conn.commit()
@@ -69,11 +90,15 @@ class DatabaseService(Service):
return super().describe_state()
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Processes the incoming SQL payload and sends the result back.
:param payload: The SQL query to be executed.
:param session_id: The session identifier.
:return: The status code of the SQL execution.
"""
result = self._process_sql(payload)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=result, session_id=session_id)
return result["status_code"]
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
pass
return result["status_code"] == 200

View File

@@ -29,7 +29,7 @@ def file_system() -> FileSystem:
return Node(hostname="fs_node").file_system
#PrimAITE v2 stuff
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.

View File

@@ -2,9 +2,9 @@ from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.data_link_layer import Frame, EthernetHeader
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence
from primaite.simulator.network.transmission.transport_layer import TCPHeader, Port
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
def test_database_query_across_the_network():
@@ -20,22 +20,15 @@ def test_database_query_across_the_network():
outbound_nic = client_1.arp.get_arp_cache_nic(IPv4Address("192.168.1.14"))
client_1.ping("192.168.1.14")
frame = Frame(
ethernet=EthernetHeader(
src_mac_addr=client_1.ethernet_port[1].mac_address,
dst_mac_addr=dst_mac_address
),
ethernet=EthernetHeader(src_mac_addr=client_1.ethernet_port[1].mac_address, dst_mac_addr=dst_mac_address),
ip=IPPacket(
src_ip_address=client_1.ethernet_port[1].ip_address,
dst_ip_address=IPv4Address("192.168.1.14"),
precedence=Precedence.FLASH
precedence=Precedence.FLASH,
),
tcp=TCPHeader(
src_port=Port.POSTGRES_SERVER,
dst_port=Port.POSTGRES_SERVER
),
payload="SELECT * FROM user;"
tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER),
payload="SELECT * FROM user;",
)
outbound_nic.send_frame(frame)
@@ -43,4 +36,4 @@ def test_database_query_across_the_network():
client_1_last_payload = outbound_nic.pcap.read()[-1]["payload"]
assert client_1_last_payload["status_code"] == 200
assert client_1_last_payload["data"]
assert client_1_last_payload["data"]

View File

@@ -29,4 +29,4 @@ def test_creation():
assert False, f"Test was not supposed to throw exception: {e}"
# there should be a session after the service is started
assert len(client_1.session_manager.sessions_by_uuid) == 1
assert len(client_1.session_manager.sessions_by_uuid) == 1

View File

@@ -56,4 +56,4 @@ def test_creation(database_server):
def test_db_population(database):
database.show()
assert database.tables() == ["user"]
assert database.tables() == ["user"]