#1816 - Added the final pieces of the puzzle to get data up from NIC → session manager → software manager → service.
- Implemented a basic sim DB that matches UC2 data manipulation DB in IY. - Added a test that confirms DB queries can be sent over the network.
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os.path
|
||||
import shutil
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
@@ -220,22 +221,14 @@ class FileSystem(SimComponent):
|
||||
dst_folder = self.create_folder(dst_folder_name)
|
||||
# add file to dst
|
||||
dst_folder.add_file(file)
|
||||
if file.real:
|
||||
old_sim_path = file.sim_path
|
||||
file.sim_path = file.folder.fs.sim_root / file.path
|
||||
file.sim_path.parent.mkdir(exist_ok=True)
|
||||
shutil.move(old_sim_path, file.sim_path)
|
||||
|
||||
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name):
|
||||
"""
|
||||
Copies a file from one folder to another.
|
||||
|
||||
can provide
|
||||
|
||||
:param file: The file to move
|
||||
:type: file: File
|
||||
|
||||
:param src_folder: The folder where the file is located
|
||||
:type: Folder
|
||||
|
||||
:param target_folder: The folder where the file should be moved to
|
||||
:type: Folder
|
||||
"""
|
||||
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
|
||||
if file:
|
||||
dst_folder = self.get_folder(folder_name=dst_folder_name)
|
||||
@@ -243,6 +236,9 @@ class FileSystem(SimComponent):
|
||||
dst_folder = self.create_folder(dst_folder_name)
|
||||
new_file = file.make_copy(dst_folder=dst_folder)
|
||||
dst_folder.add_file(new_file)
|
||||
if file.real:
|
||||
new_file.sim_path.parent.mkdir(exist_ok=True)
|
||||
shutil.copy2(file.sim_path, new_file.sim_path)
|
||||
|
||||
def get_folder(self, folder_name: str) -> Optional[Folder]:
|
||||
"""
|
||||
@@ -419,7 +415,7 @@ class File(FileSystemItemABC):
|
||||
pass
|
||||
|
||||
def make_copy(self, dst_folder: Folder) -> File:
|
||||
return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder"}))
|
||||
return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"}))
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
|
||||
@@ -921,16 +921,19 @@ class Node(SimComponent):
|
||||
kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
|
||||
if not kwargs.get("session_manager"):
|
||||
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
|
||||
if not kwargs.get("software_manager"):
|
||||
kwargs["software_manager"] = SoftwareManager(
|
||||
sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager")
|
||||
)
|
||||
if not kwargs.get("root"):
|
||||
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
|
||||
if not kwargs.get("file_system"):
|
||||
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
|
||||
if not kwargs.get("software_manager"):
|
||||
kwargs["software_manager"] = SoftwareManager(
|
||||
sys_log=kwargs.get("sys_log"),
|
||||
session_manager=kwargs.get("session_manager"),
|
||||
file_system=kwargs.get("file_system")
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self.arp.nics = self.nics
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -1097,6 +1100,8 @@ class Node(SimComponent):
|
||||
if frame.ip.protocol == IPProtocol.TCP:
|
||||
if frame.tcp.src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
|
||||
else:
|
||||
self.session_manager.receive_frame(frame)
|
||||
elif frame.ip.protocol == IPProtocol.UDP:
|
||||
pass
|
||||
elif frame.ip.protocol == IPProtocol.ICMP:
|
||||
|
||||
@@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.database import DatabaseService
|
||||
|
||||
|
||||
def client_server_routed() -> Network:
|
||||
@@ -160,6 +161,39 @@ def arcd_uc2_network() -> Network:
|
||||
database_server.power_on()
|
||||
network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])
|
||||
|
||||
ddl = """
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name VARCHAR(50) NOT NULL,
|
||||
email VARCHAR(50) NOT NULL,
|
||||
age INT,
|
||||
city VARCHAR(50),
|
||||
occupation VARCHAR(50)
|
||||
);"""
|
||||
|
||||
user_insert_statements = [
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');",
|
||||
]
|
||||
database_server.software_manager.add_service(DatabaseService)
|
||||
database: DatabaseService = database_server.software_manager.services["Database"] # noqa
|
||||
database.start()
|
||||
database._process_sql(ddl) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database._process_sql(insert_statement) # noqa
|
||||
|
||||
# Backup Server
|
||||
backup_server = Server(
|
||||
hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
@@ -183,4 +217,7 @@ def arcd_uc2_network() -> Network:
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
|
||||
|
||||
|
||||
return network
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -51,6 +52,13 @@ class PacketCapture:
|
||||
|
||||
self.logger.addFilter(_JSONFilter())
|
||||
|
||||
def read(self):
|
||||
frames = []
|
||||
with open(self._get_log_path(), "r") as file:
|
||||
while line := file.readline():
|
||||
frames.append(json.loads(line.rstrip()))
|
||||
return frames
|
||||
|
||||
@property
|
||||
def _logger_name(self) -> str:
|
||||
"""Get PCAP the logger name."""
|
||||
|
||||
@@ -32,15 +32,14 @@ class Session(SimComponent):
|
||||
"""
|
||||
|
||||
protocol: IPProtocol
|
||||
src_ip_address: IPv4Address
|
||||
dst_ip_address: IPv4Address
|
||||
with_ip_address: IPv4Address
|
||||
src_port: Optional[Port]
|
||||
dst_port: Optional[Port]
|
||||
connected: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_session_key(
|
||||
cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]
|
||||
cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]
|
||||
) -> Session:
|
||||
"""
|
||||
Create a Session instance from a session key tuple.
|
||||
@@ -48,11 +47,10 @@ class Session(SimComponent):
|
||||
:param session_key: Tuple containing the session details.
|
||||
:return: A Session instance.
|
||||
"""
|
||||
protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key
|
||||
protocol, with_ip_address, src_port, dst_port = session_key
|
||||
return Session(
|
||||
protocol=protocol,
|
||||
src_ip_address=src_ip_address,
|
||||
dst_ip_address=dst_ip_address,
|
||||
with_ip_address=with_ip_address,
|
||||
src_port=src_port,
|
||||
dst_port=dst_port,
|
||||
)
|
||||
@@ -99,8 +97,8 @@ class SessionManager:
|
||||
|
||||
@staticmethod
|
||||
def _get_session_key(
|
||||
frame: Frame, from_source: bool = True
|
||||
) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]:
|
||||
frame: Frame, inbound_frame: bool = True
|
||||
) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]:
|
||||
"""
|
||||
Extracts the session key from the given frame.
|
||||
|
||||
@@ -112,38 +110,38 @@ class SessionManager:
|
||||
- Optional[Port]: The destination port number (if applicable).
|
||||
|
||||
:param frame: The network frame from which to extract the session key.
|
||||
:param from_source: A flag to indicate if the key should be extracted from the source or destination.
|
||||
:return: A tuple containing the session key.
|
||||
"""
|
||||
protocol = frame.ip.protocol
|
||||
src_ip_address = frame.ip.src_ip_address
|
||||
dst_ip_address = frame.ip.dst_ip_address
|
||||
with_ip_address = frame.ip.src_ip_address
|
||||
if protocol == IPProtocol.TCP:
|
||||
if from_source:
|
||||
if inbound_frame:
|
||||
src_port = frame.tcp.src_port
|
||||
dst_port = frame.tcp.dst_port
|
||||
else:
|
||||
dst_port = frame.tcp.src_port
|
||||
src_port = frame.tcp.dst_port
|
||||
with_ip_address = frame.ip.dst_ip_address
|
||||
elif protocol == IPProtocol.UDP:
|
||||
if from_source:
|
||||
if inbound_frame:
|
||||
src_port = frame.udp.src_port
|
||||
dst_port = frame.udp.dst_port
|
||||
else:
|
||||
dst_port = frame.udp.src_port
|
||||
src_port = frame.udp.dst_port
|
||||
with_ip_address = frame.ip.dst_ip_address
|
||||
else:
|
||||
src_port = None
|
||||
dst_port = None
|
||||
return protocol, src_ip_address, dst_ip_address, src_port, dst_port
|
||||
return protocol, with_ip_address, src_port, dst_port
|
||||
|
||||
def receive_payload_from_software_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
self,
|
||||
payload: Any,
|
||||
dst_ip_address: Optional[IPv4Address] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Union[Any, None]:
|
||||
"""
|
||||
Receive a payload from the SoftwareManager.
|
||||
@@ -154,23 +152,21 @@ class SessionManager:
|
||||
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
|
||||
"""
|
||||
if session_id:
|
||||
dest_ip_address = self.sessions_by_uuid[session_id].dst_ip_address
|
||||
dest_port = self.sessions_by_uuid[session_id].dst_port
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
|
||||
dst_port = self.sessions_by_uuid[session_id].dst_port
|
||||
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dest_ip_address)
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dest_ip_address)
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dest_ip_address)
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=dest_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
payload=payload, dst_ip_address=dst_ip_address, dst_port=dst_port, session_id=session_id,
|
||||
is_reattempt=True
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
@@ -178,17 +174,17 @@ class SessionManager:
|
||||
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
src_ip_address=outbound_nic.ip_address,
|
||||
dst_ip_address=dest_ip_address,
|
||||
dst_ip_address=dst_ip_address,
|
||||
),
|
||||
tcp=TCPHeader(
|
||||
src_port=dest_port,
|
||||
dst_port=dest_port,
|
||||
src_port=dst_port,
|
||||
dst_port=dst_port,
|
||||
),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
session_key = self._get_session_key(frame, from_source=True)
|
||||
session_key = self._get_session_key(frame, inbound_frame=False)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
@@ -198,33 +194,25 @@ class SessionManager:
|
||||
|
||||
outbound_nic.send_frame(frame)
|
||||
|
||||
def send_payload_to_software_manager(self, payload: Any, session_id: int):
|
||||
def receive_frame(self, frame: Frame):
|
||||
"""
|
||||
Send a payload to the software manager.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload originates from.
|
||||
"""
|
||||
self.software_manager.receive_payload_from_session_manger()
|
||||
|
||||
def receive_payload_from_nic(self, frame: Frame):
|
||||
"""
|
||||
Receive a Frame from the NIC.
|
||||
Receive a Frame.
|
||||
|
||||
Extract the session key using the _get_session_key method, and forward the payload to the appropriate
|
||||
session. If the session does not exist, a new one is created.
|
||||
|
||||
:param frame: The frame being received.
|
||||
"""
|
||||
session_key = self._get_session_key(frame)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
session_key = self._get_session_key(frame, inbound_frame=True)
|
||||
session: Session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
self.software_manager.receive_payload_from_session_manger(payload=frame, session=session)
|
||||
# TODO: Implement the frame deconstruction and send to SoftwareManager.
|
||||
self.software_manager.receive_payload_from_session_manger(
|
||||
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
|
||||
)
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -23,7 +24,7 @@ ServiceClass = TypeVar("ServiceClass", bound=Service)
|
||||
class SoftwareManager:
|
||||
"""A class that manages all running Services and Applications on a Node and facilitates their communication."""
|
||||
|
||||
def __init__(self, session_manager: "SessionManager", sys_log: "SysLog"):
|
||||
def __init__(self, session_manager: "SessionManager", sys_log: SysLog, file_system: FileSystem):
|
||||
"""
|
||||
Initialize a new instance of SoftwareManager.
|
||||
|
||||
@@ -34,6 +35,7 @@ class SoftwareManager:
|
||||
self.applications: Dict[str, Application] = {}
|
||||
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
|
||||
self.sys_log: SysLog = sys_log
|
||||
self.file_system: FileSystem = file_system
|
||||
|
||||
def add_service(self, service_class: Type[ServiceClass]):
|
||||
"""
|
||||
@@ -41,7 +43,7 @@ class SoftwareManager:
|
||||
|
||||
:param: service_class: The class of the service to add
|
||||
"""
|
||||
service = service_class(software_manager=self, sys_log=self.sys_log)
|
||||
service = service_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
|
||||
|
||||
service.software_manager = self
|
||||
self.services[service.name] = service
|
||||
@@ -86,7 +88,7 @@ class SoftwareManager:
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[int] = None,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Send a payload to the SessionManager.
|
||||
@@ -97,21 +99,21 @@ class SoftwareManager:
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
"""
|
||||
self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
|
||||
)
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manger(self, payload: Any, session: Session):
|
||||
def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
"""
|
||||
Receive a payload from the SessionManager and forward it to the corresponding service or application.
|
||||
|
||||
:param payload: The payload being received.
|
||||
:param session: The transport session the payload originates from.
|
||||
"""
|
||||
# receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
|
||||
# if receiver:
|
||||
# receiver.receive_payload(None, payload)
|
||||
# else:
|
||||
# raise ValueError(f"No service or application found for port {port} and protocol {protocol}")
|
||||
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
|
||||
if receiver:
|
||||
receiver.receive(payload=payload, session_id=session_id)
|
||||
else:
|
||||
self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}")
|
||||
pass
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
|
||||
@@ -1,15 +1,61 @@
|
||||
from typing import Dict
|
||||
import sqlite3
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Dict, Optional, Any, List, Union
|
||||
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from prettytable import PrettyTable, MARKDOWN
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
"""A generic SQL Server Service."""
|
||||
backup_server: Optional[IPv4Address] = None
|
||||
"The IP Address of the server the "
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "Database"
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
self._conn = sqlite3.connect(self._db_file.sim_path)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def tables(self) -> List[str]:
|
||||
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
|
||||
results = self._process_sql(sql)
|
||||
return [row[0] for row in results["data"]]
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""Prints a Table names in the Database."""
|
||||
table = PrettyTable(["Table"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.file_system.sys_log.hostname} Database"
|
||||
for row in self.tables():
|
||||
table.add_row([row])
|
||||
print(table)
|
||||
|
||||
def _create_db_file(self):
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
|
||||
def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
return {"status_code": 404, "data": []}
|
||||
|
||||
return {"status_code": 200, "data": self._cursor.fetchall()}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -22,53 +68,12 @@ class DatabaseService(Service):
|
||||
"""
|
||||
return super().describe_state()
|
||||
|
||||
@classmethod
|
||||
def install(cls, node: Node):
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
result = self._process_sql(payload)
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(payload=result, session_id=session_id)
|
||||
|
||||
return result["status_code"]
|
||||
|
||||
def uninstall(self) -> None:
|
||||
"""
|
||||
Undo installation procedure.
|
||||
|
||||
This method deletes files created when installing the database, and the database folder if it is empty.
|
||||
"""
|
||||
super().uninstall()
|
||||
node: Node = self.parent
|
||||
node.file_system.delete_file(self.primary_store)
|
||||
node.file_system.delete_file(self.transaction_log)
|
||||
if self.secondary_store:
|
||||
node.file_system.delete_file(self.secondary_store)
|
||||
if len(self.folder.files) == 0:
|
||||
node.file_system.delete_folder(self.folder)
|
||||
|
||||
def install(self) -> None:
|
||||
"""Perform first time install on a node, creating necessary files."""
|
||||
super().install()
|
||||
assert isinstance(self.parent, Node), "Database install can only happen after the db service is added to a node"
|
||||
self._setup_files()
|
||||
|
||||
def _setup_files(
|
||||
self,
|
||||
folder_name: str = "database",
|
||||
):
|
||||
"""Set up files that are required by the database on the parent host.
|
||||
|
||||
:param folder_name: Name of the folder which will be setup to hold the db files, defaults to "database"
|
||||
:type folder_name: str, optional
|
||||
"""
|
||||
# note that this parent.file_system.create_folder call in the future will be authenticated by using permissions
|
||||
# handler. This permission will be granted based on service account given to the database service.
|
||||
self.parent: Node
|
||||
self.folder = self.parent.file_system.create_folder(folder_name)
|
||||
self.primary_store = self.parent.file_system.create_file(
|
||||
"db_primary_store", db_size, FileType.MDF, folder=self.folder
|
||||
)
|
||||
self.transaction_log = self.parent.file_system.create_file(
|
||||
"db_transaction_log", "1", FileType.LDF, folder=self.folder
|
||||
)
|
||||
if use_secondary_db_file:
|
||||
self.secondary_store = self.parent.file_system.create_file(
|
||||
"db_secondary_store", secondary_db_size, FileType.NDF, folder=self.folder
|
||||
)
|
||||
else:
|
||||
self.secondary_store = None
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite.simulator.core import Action, ActionManager, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
@@ -79,6 +80,10 @@ class Software(SimComponent):
|
||||
"An instance of Software Manager that is used by the parent node."
|
||||
sys_log: SysLog = None
|
||||
"An instance of SysLog that is used by the parent node."
|
||||
file_system: FileSystem
|
||||
"The FileSystem of the Node the Software is installed on."
|
||||
folder: Optional[Folder] = None
|
||||
"The folder on the file system the Software uses."
|
||||
|
||||
def _init_action_manager(self) -> ActionManager:
|
||||
am = super()._init_action_manager()
|
||||
@@ -216,7 +221,6 @@ class IOSoftware(Software):
|
||||
:param kwargs: Additional keyword arguments specific to the implementation.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,17 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
# PrimAITE v3 stuff
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def file_system() -> FileSystem:
|
||||
return Node(hostname="fs_node").file_system
|
||||
|
||||
|
||||
#PrimAITE v2 stuff
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
@@ -1,52 +1,46 @@
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.database import DatabaseService
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame, EthernetHeader
|
||||
from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence
|
||||
from primaite.simulator.network.transmission.transport_layer import TCPHeader, Port
|
||||
|
||||
|
||||
def test_installing_database():
|
||||
db = DatabaseService(
|
||||
name="SQL-database",
|
||||
health_state_actual=SoftwareHealthState.GOOD,
|
||||
health_state_visible=SoftwareHealthState.GOOD,
|
||||
criticality=SoftwareCriticality.MEDIUM,
|
||||
port=Port.SQL_SERVER,
|
||||
operating_state=ServiceOperatingState.RUNNING,
|
||||
def test_database_query_across_the_network():
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
network = arcd_uc2_network()
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
|
||||
client_1.arp.send_arp_request(IPv4Address("192.168.1.14"))
|
||||
|
||||
dst_mac_address = client_1.arp.get_arp_cache_mac_address(IPv4Address("192.168.1.14"))
|
||||
|
||||
outbound_nic = client_1.arp.get_arp_cache_nic(IPv4Address("192.168.1.14"))
|
||||
client_1.ping("192.168.1.14")
|
||||
|
||||
|
||||
frame = Frame(
|
||||
ethernet=EthernetHeader(
|
||||
src_mac_addr=client_1.ethernet_port[1].mac_address,
|
||||
dst_mac_addr=dst_mac_address
|
||||
),
|
||||
ip=IPPacket(
|
||||
src_ip_address=client_1.ethernet_port[1].ip_address,
|
||||
dst_ip_address=IPv4Address("192.168.1.14"),
|
||||
precedence=Precedence.FLASH
|
||||
),
|
||||
tcp=TCPHeader(
|
||||
src_port=Port.POSTGRES_SERVER,
|
||||
dst_port=Port.POSTGRES_SERVER
|
||||
),
|
||||
payload="SELECT * FROM user;"
|
||||
)
|
||||
|
||||
node = Node(hostname="db-server")
|
||||
outbound_nic.send_frame(frame)
|
||||
|
||||
node.install_service(db)
|
||||
client_1_last_payload = outbound_nic.pcap.read()[-1]["payload"]
|
||||
|
||||
assert db in node
|
||||
|
||||
file_exists = False
|
||||
for folder in node.file_system.folders.values():
|
||||
for file in folder.files.values():
|
||||
if file.name == "db_primary_store":
|
||||
file_exists = True
|
||||
break
|
||||
if file_exists:
|
||||
break
|
||||
assert file_exists
|
||||
|
||||
|
||||
def test_uninstalling_database():
|
||||
db = DatabaseService(
|
||||
name="SQL-database",
|
||||
health_state_actual=SoftwareHealthState.GOOD,
|
||||
health_state_visible=SoftwareHealthState.GOOD,
|
||||
criticality=SoftwareCriticality.MEDIUM,
|
||||
port=Port.SQL_SERVER,
|
||||
operating_state=ServiceOperatingState.RUNNING,
|
||||
)
|
||||
|
||||
node = Node(hostname="db-server")
|
||||
|
||||
node.install_service(db)
|
||||
|
||||
node.uninstall_service(db)
|
||||
|
||||
assert db not in node
|
||||
assert node.file_system.get_folder("database") is None
|
||||
assert client_1_last_payload["status_code"] == 200
|
||||
assert client_1_last_payload["data"]
|
||||
@@ -1,19 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.file_system.file_system import File, FileSystem, Folder
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def file_system() -> FileSystem:
|
||||
return Node(hostname="fs_node").file_system
|
||||
|
||||
|
||||
def test_create_folder_and_file(file_system):
|
||||
"""Test creating a folder and a file."""
|
||||
assert len(file_system.folders) == 1
|
||||
test_folder = file_system.create_folder(folder_name="test_folder")
|
||||
file_system.create_folder(folder_name="test_folder")
|
||||
|
||||
assert len(file_system.folders) is 2
|
||||
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
@@ -115,7 +109,7 @@ def test_copy_file(file_system):
|
||||
file_system.create_folder(folder_name="src_folder")
|
||||
file_system.create_folder(folder_name="dst_folder")
|
||||
|
||||
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder")
|
||||
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True)
|
||||
original_uuid = file.uuid
|
||||
|
||||
assert len(file_system.get_folder("src_folder").files) == 1
|
||||
@@ -128,6 +122,19 @@ def test_copy_file(file_system):
|
||||
assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid
|
||||
|
||||
|
||||
def test_folder_quarantine_state(file_system):
|
||||
"""Tests the changing of folder quarantine status."""
|
||||
folder = file_system.get_folder("root")
|
||||
|
||||
assert folder.quarantine_status() is False
|
||||
|
||||
folder.quarantine()
|
||||
assert folder.quarantine_status() is True
|
||||
|
||||
folder.unquarantine()
|
||||
assert folder.quarantine_status() is False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
|
||||
def test_serialisation(file_system):
|
||||
"""Test to check that the object serialisation works correctly."""
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
|
||||
|
||||
def test_file_type():
|
||||
"""Tests tha the File type is set correctly."""
|
||||
file = File(name="test", file_type=FileType.DOC)
|
||||
assert file.file_type is FileType.DOC
|
||||
|
||||
|
||||
def test_get_size():
|
||||
"""Tests that the file size is being returned properly."""
|
||||
file = File(name="test", size=1.5)
|
||||
assert file.size == 1.5
|
||||
|
||||
|
||||
def test_serialisation():
|
||||
"""Test to check that the object serialisation works correctly."""
|
||||
file = File(name="test", size=1.5, file_type=FileType.DOC)
|
||||
serialised_file = file.model_dump_json()
|
||||
deserialised_file = File.model_validate_json(serialised_file)
|
||||
|
||||
assert file.model_dump_json() == deserialised_file.model_dump_json()
|
||||
@@ -1,75 +0,0 @@
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_folder import Folder
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
|
||||
|
||||
def test_adding_removing_file():
|
||||
"""Test the adding and removing of a file from a folder."""
|
||||
folder = Folder(name="test")
|
||||
|
||||
file = File(name="test_file", size=10, file_type=FileType.DOC)
|
||||
|
||||
folder.add_file(file)
|
||||
assert folder.size == 10
|
||||
assert len(folder.files) is 1
|
||||
|
||||
folder.remove_file(file)
|
||||
assert folder.size == 0
|
||||
assert len(folder.files) is 0
|
||||
|
||||
|
||||
def test_remove_non_existent_file():
|
||||
"""Test the removing of a file that does not exist."""
|
||||
folder = Folder(name="test")
|
||||
|
||||
file = File(name="test_file", size=10, file_type=FileType.DOC)
|
||||
not_added_file = File(name="fake_file", size=10, file_type=FileType.DOC)
|
||||
|
||||
folder.add_file(file)
|
||||
assert folder.size == 10
|
||||
assert len(folder.files) is 1
|
||||
|
||||
folder.remove_file(not_added_file)
|
||||
assert folder.size == 10
|
||||
assert len(folder.files) is 1
|
||||
|
||||
|
||||
def test_get_file_by_id():
|
||||
"""Test to make sure that the correct file is returned."""
|
||||
folder = Folder(name="test")
|
||||
|
||||
file = File(name="test_file", size=10, file_type=FileType.DOC)
|
||||
file2 = File(name="test_file_2", size=10, file_type=FileType.DOC)
|
||||
|
||||
folder.add_file(file)
|
||||
folder.add_file(file2)
|
||||
assert folder.size == 20
|
||||
assert len(folder.files) is 2
|
||||
|
||||
assert folder.get_file_by_id(file_id=file.uuid) is file
|
||||
|
||||
|
||||
def test_folder_quarantine_state():
|
||||
"""Tests the changing of folder quarantine status."""
|
||||
folder = Folder(name="test")
|
||||
|
||||
assert folder.quarantine_status() is False
|
||||
|
||||
folder.quarantine()
|
||||
assert folder.quarantine_status() is True
|
||||
|
||||
folder.unquarantine()
|
||||
assert folder.quarantine_status() is False
|
||||
|
||||
|
||||
def test_serialisation():
|
||||
"""Test to check that the object serialisation works correctly."""
|
||||
folder = Folder(name="test")
|
||||
file = File(name="test_file", size=10, file_type=FileType.DOC)
|
||||
folder.add_file(file)
|
||||
|
||||
serialised_folder = folder.model_dump_json()
|
||||
|
||||
deserialised_folder = Folder.model_validate_json(serialised_folder)
|
||||
|
||||
assert folder.model_dump_json() == deserialised_folder.model_dump_json()
|
||||
@@ -29,4 +29,4 @@ def test_creation():
|
||||
assert False, f"Test was not supposed to throw exception: {e}"
|
||||
|
||||
# there should be a session after the service is started
|
||||
assert len(client_1.session_manager.sessions_by_uuid) == 1
|
||||
assert len(client_1.session_manager.sessions_by_uuid) == 1
|
||||
@@ -1,15 +1,59 @@
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.system.services.database import DatabaseService
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
|
||||
|
||||
DDL = """
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name VARCHAR(50) NOT NULL,
|
||||
email VARCHAR(50) NOT NULL,
|
||||
age INT,
|
||||
city VARCHAR(50),
|
||||
occupation VARCHAR(50)
|
||||
);"""
|
||||
|
||||
USER_INSERT_STATEMENTS = [
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');",
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');",
|
||||
]
|
||||
|
||||
|
||||
def test_creation():
|
||||
db = DatabaseService(
|
||||
name="SQL-database",
|
||||
health_state_actual=SoftwareHealthState.GOOD,
|
||||
health_state_visible=SoftwareHealthState.GOOD,
|
||||
criticality=SoftwareCriticality.MEDIUM,
|
||||
port=Port.SQL_SERVER,
|
||||
operating_state=ServiceOperatingState.RUNNING,
|
||||
)
|
||||
@pytest.fixture(scope="function")
|
||||
def database_server() -> Node:
|
||||
node = Node(hostname="db_node")
|
||||
node.software_manager.add_service(DatabaseService)
|
||||
node.software_manager.services["Database"].start()
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def database(database_server) -> DatabaseService:
|
||||
database: DatabaseService = database_server.software_manager.services["Database"] # noqa
|
||||
database.receive(DDL, None)
|
||||
for script in USER_INSERT_STATEMENTS:
|
||||
database.receive(script, None)
|
||||
return database
|
||||
|
||||
|
||||
def test_creation(database_server):
|
||||
database_server.software_manager.show()
|
||||
|
||||
|
||||
def test_db_population(database):
|
||||
database.show()
|
||||
assert database.tables() == ["user"]
|
||||
Reference in New Issue
Block a user