#1816 - Added database client. Installed the database client on the Web Server node in the UC2 network. Updated the integration test to query the DB server using the DB client.
This commit is contained in:
170
src/primaite/simulator/file_system/file_type.py
Normal file
170
src/primaite/simulator/file_system/file_type.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from random import choice
|
||||
|
||||
|
||||
class FileType(Enum):
|
||||
"""An enumeration of common file types."""
|
||||
|
||||
UNKNOWN = 0
|
||||
"Unknown file type."
|
||||
|
||||
# Text formats
|
||||
TXT = 1
|
||||
"Plain text file."
|
||||
DOC = 2
|
||||
"Microsoft Word document (.doc)"
|
||||
DOCX = 3
|
||||
"Microsoft Word document (.docx)"
|
||||
PDF = 4
|
||||
"Portable Document Format."
|
||||
HTML = 5
|
||||
"HyperText Markup Language file."
|
||||
XML = 6
|
||||
"Extensible Markup Language file."
|
||||
CSV = 7
|
||||
"Comma-Separated Values file."
|
||||
|
||||
# Spreadsheet formats
|
||||
XLS = 8
|
||||
"Microsoft Excel file (.xls)"
|
||||
XLSX = 9
|
||||
"Microsoft Excel file (.xlsx)"
|
||||
|
||||
# Image formats
|
||||
JPEG = 10
|
||||
"JPEG image file."
|
||||
PNG = 11
|
||||
"PNG image file."
|
||||
GIF = 12
|
||||
"GIF image file."
|
||||
BMP = 13
|
||||
"Bitmap image file."
|
||||
|
||||
# Audio formats
|
||||
MP3 = 14
|
||||
"MP3 audio file."
|
||||
WAV = 15
|
||||
"WAV audio file."
|
||||
|
||||
# Video formats
|
||||
MP4 = 16
|
||||
"MP4 video file."
|
||||
AVI = 17
|
||||
"AVI video file."
|
||||
MKV = 18
|
||||
"MKV video file."
|
||||
FLV = 19
|
||||
"FLV video file."
|
||||
|
||||
# Presentation formats
|
||||
PPT = 20
|
||||
"Microsoft PowerPoint file (.ppt)"
|
||||
PPTX = 21
|
||||
"Microsoft PowerPoint file (.pptx)"
|
||||
|
||||
# Web formats
|
||||
JS = 22
|
||||
"JavaScript file."
|
||||
CSS = 23
|
||||
"Cascading Style Sheets file."
|
||||
|
||||
# Programming languages
|
||||
PY = 24
|
||||
"Python script file."
|
||||
C = 25
|
||||
"C source code file."
|
||||
CPP = 26
|
||||
"C++ source code file."
|
||||
JAVA = 27
|
||||
"Java source code file."
|
||||
|
||||
# Compressed file types
|
||||
RAR = 28
|
||||
"RAR archive file."
|
||||
ZIP = 29
|
||||
"ZIP archive file."
|
||||
TAR = 30
|
||||
"TAR archive file."
|
||||
GZ = 31
|
||||
"Gzip compressed file."
|
||||
|
||||
# Database file types
|
||||
DB = 32
|
||||
"Generic DB file. Used by sqlite3."
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
return cls.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def random(cls) -> FileType:
|
||||
"""
|
||||
Returns a random FileType.
|
||||
|
||||
:return: A random FileType.
|
||||
"""
|
||||
return choice(list(FileType))
|
||||
|
||||
@property
|
||||
def default_size(self) -> int:
|
||||
"""
|
||||
Get the default size of the FileType in bytes.
|
||||
|
||||
Returns 0 if a default size does not exist.
|
||||
"""
|
||||
size = file_type_sizes_bytes[self]
|
||||
return size if size else 0
|
||||
|
||||
|
||||
def get_file_type_from_extension(file_type_extension: str):
|
||||
"""
|
||||
Get a FileType from a file type extension.
|
||||
|
||||
If a matching extension does not exist, FileType.UNKNOWN is returned.
|
||||
|
||||
:param file_type_extension: A file type extension.
|
||||
:return: A file type extension.
|
||||
"""
|
||||
try:
|
||||
return FileType[file_type_extension.upper()]
|
||||
except KeyError:
|
||||
return FileType.UNKNOWN
|
||||
|
||||
|
||||
file_type_sizes_bytes = {
|
||||
FileType.UNKNOWN: 0,
|
||||
FileType.TXT: 4096,
|
||||
FileType.DOC: 51200,
|
||||
FileType.DOCX: 30720,
|
||||
FileType.PDF: 102400,
|
||||
FileType.HTML: 15360,
|
||||
FileType.XML: 10240,
|
||||
FileType.CSV: 15360,
|
||||
FileType.XLS: 102400,
|
||||
FileType.XLSX: 25600,
|
||||
FileType.JPEG: 102400,
|
||||
FileType.PNG: 40960,
|
||||
FileType.GIF: 30720,
|
||||
FileType.BMP: 307200,
|
||||
FileType.MP3: 5120000,
|
||||
FileType.WAV: 25600000,
|
||||
FileType.MP4: 25600000,
|
||||
FileType.AVI: 51200000,
|
||||
FileType.MKV: 51200000,
|
||||
FileType.FLV: 15360000,
|
||||
FileType.PPT: 204800,
|
||||
FileType.PPTX: 102400,
|
||||
FileType.JS: 10240,
|
||||
FileType.CSS: 5120,
|
||||
FileType.PY: 5120,
|
||||
FileType.C: 5120,
|
||||
FileType.CPP: 10240,
|
||||
FileType.JAVA: 10240,
|
||||
FileType.RAR: 1024000,
|
||||
FileType.ZIP: 1024000,
|
||||
FileType.TAR: 1024000,
|
||||
FileType.GZ: 819200,
|
||||
FileType.DB: 15360000,
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import secrets
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -959,7 +959,24 @@ class Node(SimComponent):
|
||||
)
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
def show(self, markdown: bool = False, component: Literal["NIC", "OPEN_PORTS"] = "NIC"):
|
||||
if component == "NIC":
|
||||
self._show_nic(markdown)
|
||||
elif component == "OPEN_PORTS":
|
||||
self._show_open_ports(markdown)
|
||||
|
||||
def _show_open_ports(self, markdown: bool = False):
|
||||
"""Prints a table of the open ports on the Node."""
|
||||
table = PrettyTable(["Port", "Name"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.hostname} Open Ports"
|
||||
for port in self.software_manager.get_open_ports():
|
||||
table.add_row([port.value, port.name])
|
||||
print(table)
|
||||
|
||||
def _show_nic(self, markdown: bool = False):
|
||||
"""Prints a table of the NICs on the Node."""
|
||||
table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"])
|
||||
if markdown:
|
||||
@@ -1048,29 +1065,30 @@ class Node(SimComponent):
|
||||
:param pings: The number of pings to attempt, default is 4.
|
||||
:return: True if the ping is successful, otherwise False.
|
||||
"""
|
||||
if not isinstance(target_ip_address, IPv4Address):
|
||||
target_ip_address = IPv4Address(target_ip_address)
|
||||
if target_ip_address.is_loopback:
|
||||
self.sys_log.info("Pinging loopback address")
|
||||
return any(nic.enabled for nic in self.nics.values())
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
self.sys_log.info(f"Pinging {target_ip_address}:")
|
||||
sequence, identifier = 0, None
|
||||
while sequence < pings:
|
||||
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
|
||||
request_replies = self.icmp.request_replies.get(identifier)
|
||||
passed = request_replies == pings
|
||||
if request_replies:
|
||||
self.icmp.request_replies.pop(identifier)
|
||||
else:
|
||||
request_replies = 0
|
||||
self.sys_log.info(
|
||||
f"Ping statistics for {target_ip_address}: "
|
||||
f"Packets: Sent = {pings}, "
|
||||
f"Received = {request_replies}, "
|
||||
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
|
||||
)
|
||||
return passed
|
||||
if not isinstance(target_ip_address, IPv4Address):
|
||||
target_ip_address = IPv4Address(target_ip_address)
|
||||
if target_ip_address.is_loopback:
|
||||
self.sys_log.info("Pinging loopback address")
|
||||
return any(nic.enabled for nic in self.nics.values())
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
self.sys_log.info(f"Pinging {target_ip_address}:")
|
||||
sequence, identifier = 0, None
|
||||
while sequence < pings:
|
||||
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
|
||||
request_replies = self.icmp.request_replies.get(identifier)
|
||||
passed = request_replies == pings
|
||||
if request_replies:
|
||||
self.icmp.request_replies.pop(identifier)
|
||||
else:
|
||||
request_replies = 0
|
||||
self.sys_log.info(
|
||||
f"Ping statistics for {target_ip_address}: "
|
||||
f"Packets: Sent = {pings}, "
|
||||
f"Received = {request_replies}, "
|
||||
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
|
||||
)
|
||||
return passed
|
||||
return False
|
||||
|
||||
def send_frame(self, frame: Frame):
|
||||
@@ -1079,7 +1097,8 @@ class Node(SimComponent):
|
||||
|
||||
:param frame: The Frame to be sent.
|
||||
"""
|
||||
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
|
||||
nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame, from_nic: NIC):
|
||||
@@ -1092,20 +1111,27 @@ class Node(SimComponent):
|
||||
:param frame: The Frame being received.
|
||||
:param from_nic: The NIC that received the frame.
|
||||
"""
|
||||
if frame.ip:
|
||||
if frame.ip.src_ip_address in self.arp:
|
||||
self.arp.add_arp_cache_entry(
|
||||
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
|
||||
)
|
||||
if frame.ip.protocol == IPProtocol.TCP:
|
||||
if frame.tcp.src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
if frame.ip:
|
||||
if frame.ip.src_ip_address in self.arp:
|
||||
self.arp.add_arp_cache_entry(
|
||||
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
|
||||
)
|
||||
if frame.ip.protocol == IPProtocol.ICMP:
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
return
|
||||
# Check if the destination port is open on the Node
|
||||
if frame.tcp.dst_port in self.software_manager.get_open_ports():
|
||||
# accept thr frame as the port is open
|
||||
if frame.tcp.src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
|
||||
else:
|
||||
self.session_manager.receive_frame(frame)
|
||||
else:
|
||||
self.session_manager.receive_frame(frame)
|
||||
elif frame.ip.protocol == IPProtocol.UDP:
|
||||
pass
|
||||
elif frame.ip.protocol == IPProtocol.ICMP:
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
# denied as port closed
|
||||
self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}")
|
||||
# TODO: do we need to do anything more here?
|
||||
pass
|
||||
|
||||
def install_service(self, service: Service) -> None:
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database import DatabaseService
|
||||
|
||||
|
||||
@@ -149,6 +150,9 @@ def arcd_uc2_network() -> Network:
|
||||
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
)
|
||||
web_server.power_on()
|
||||
web_server.software_manager.install(DatabaseClient)
|
||||
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
database_client.run()
|
||||
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
|
||||
|
||||
# Database Server
|
||||
@@ -187,12 +191,12 @@ def arcd_uc2_network() -> Network:
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa
|
||||
]
|
||||
database_server.software_manager.add_service(DatabaseService)
|
||||
database: DatabaseService = database_server.software_manager.services["Database"] # noqa
|
||||
database.start()
|
||||
database._process_sql(ddl) # noqa
|
||||
database_server.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
|
||||
database_service.start()
|
||||
database_service._process_sql(ddl) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database._process_sql(insert_statement) # noqa
|
||||
database_service._process_sql(insert_statement) # noqa
|
||||
|
||||
# Backup Server
|
||||
backup_server = Server(
|
||||
|
||||
@@ -23,9 +23,9 @@ class Application(IOSoftware):
|
||||
Applications are user-facing programs that may perform input/output operations.
|
||||
"""
|
||||
|
||||
operating_state: ApplicationOperatingState
|
||||
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
|
||||
"The current operating state of the Application."
|
||||
execution_control_status: str
|
||||
execution_control_status: str = "manual"
|
||||
"Control status of the application's execution. It could be 'manual' or 'automatic'."
|
||||
num_executions: int = 0
|
||||
"The number of times the application has been executed. Default is 0."
|
||||
@@ -53,6 +53,25 @@ class Application(IOSoftware):
|
||||
)
|
||||
return state
|
||||
|
||||
def run(self) -> None:
|
||||
"""Open the Application"""
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Running Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.RUNNING
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the Application"""
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Closed Application{self.name}")
|
||||
self.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
def install(self) -> None:
|
||||
"""Install Application."""
|
||||
super().install()
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Installing Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.INSTALLING
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Application component for a new episode.
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
connected: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return super().describe_state()
|
||||
|
||||
def connect(self, server_ip_address: IPv4Address, password: Optional[str] = None) -> bool:
|
||||
if not self.connected and self.operating_state.RUNNING:
|
||||
return self._connect(server_ip_address, password)
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
) -> bool:
|
||||
if is_reattempt:
|
||||
if self.connected:
|
||||
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised")
|
||||
self.server_ip_address = server_ip_address
|
||||
return self.connected
|
||||
else:
|
||||
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined")
|
||||
payload = {"type": "connect_request", "password": password}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(server_ip_address, password, True)
|
||||
|
||||
def disconnect(self):
|
||||
if self.connected and self.operating_state.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
|
||||
self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}")
|
||||
self.server_ip_address = None
|
||||
|
||||
def query(self, sql: str):
|
||||
if self.connected and self.operating_state.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "sql", "sql": sql}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
|
||||
def _print_data(self, data: Dict):
|
||||
"""
|
||||
Display the contents of the Folder in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(list(data.values())[0])
|
||||
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Database Client"
|
||||
for row in data.values():
|
||||
table.add_row(row.values())
|
||||
print(table)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
elif payload["type"] == "sql":
|
||||
self._print_data(payload["data"])
|
||||
return True
|
||||
@@ -1,15 +1,15 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareType
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from typing import Type, TypeVar
|
||||
|
||||
ServiceClass = TypeVar("ServiceClass", bound=Service)
|
||||
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
|
||||
|
||||
class SoftwareManager:
|
||||
@@ -30,57 +30,55 @@ class SoftwareManager:
|
||||
:param session_manager: The session manager handling network communications.
|
||||
"""
|
||||
self.session_manager = session_manager
|
||||
self.services: Dict[str, Service] = {}
|
||||
self.applications: Dict[str, Application] = {}
|
||||
self.software: Dict[str, Union[Service, Application]] = {}
|
||||
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
|
||||
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
|
||||
self.sys_log: SysLog = sys_log
|
||||
self.file_system: FileSystem = file_system
|
||||
|
||||
def add_service(self, service_class: Type[ServiceClass]):
|
||||
"""
|
||||
Add a Service to the manager.
|
||||
def get_open_ports(self) -> List[Port]:
|
||||
open_ports = [Port.ARP]
|
||||
for software in self.port_protocol_mapping.values():
|
||||
if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}:
|
||||
open_ports.append(software.port)
|
||||
open_ports.sort(key=lambda port: port.value)
|
||||
return open_ports
|
||||
|
||||
:param: service_class: The class of the service to add
|
||||
"""
|
||||
service = service_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
|
||||
def install(self, software_class: Type[IOSoftwareClass]):
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
|
||||
if isinstance(software, Application):
|
||||
software.install()
|
||||
software.software_manager = self
|
||||
self.software[software.name] = software
|
||||
self.port_protocol_mapping[(software.port, software.protocol)] = software
|
||||
self.sys_log.info(f"Installed {software.name}")
|
||||
if isinstance(software, Application):
|
||||
software.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
service.software_manager = self
|
||||
self.services[service.name] = service
|
||||
self.port_protocol_mapping[(service.port, service.protocol)] = service
|
||||
def uninstall(self, software_name: str):
|
||||
if software_name in self.software:
|
||||
software = self.software.pop(software_name) # noqa
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol):
|
||||
"""
|
||||
Add an Application to the manager.
|
||||
|
||||
:param name: The name of the application.
|
||||
:param application: The application instance.
|
||||
:param port: The port used by the application.
|
||||
:param protocol: The network protocol used by the application.
|
||||
"""
|
||||
application.software_manager = self
|
||||
self.applications[name] = application
|
||||
self.port_protocol_mapping[(port, protocol)] = application
|
||||
|
||||
def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any):
|
||||
def send_internal_payload(self, target_software: str, payload: Any):
|
||||
"""
|
||||
Send a payload to a specific service or application.
|
||||
|
||||
:param target_software: The name of the target service or application.
|
||||
:param target_software_type: The type of software (Service, Application, Process).
|
||||
:param payload: The data to be sent.
|
||||
:param receiver_type: The type of the target, either 'service' or 'application'.
|
||||
"""
|
||||
if target_software_type is SoftwareType.SERVICE:
|
||||
receiver = self.services.get(target_software)
|
||||
elif target_software_type is SoftwareType.APPLICATION:
|
||||
receiver = self.applications.get(target_software)
|
||||
else:
|
||||
raise ValueError(f"Invalid receiver type {target_software_type}")
|
||||
receiver = self.software.get(target_software)
|
||||
|
||||
if receiver:
|
||||
receiver.receive_payload(payload)
|
||||
else:
|
||||
raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}")
|
||||
self.sys_log.error(f"No Service of Application found with the name {target_software}")
|
||||
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
@@ -121,13 +119,20 @@ class SoftwareManager:
|
||||
|
||||
:param markdown: If True, outputs the table in markdown format. Default is False.
|
||||
"""
|
||||
table = PrettyTable(["Name", "Operating State", "Health State", "Port"])
|
||||
table = PrettyTable(["Name", "Type", "Operating State", "Health State", "Port"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Software Manager"
|
||||
for service in self.services.values():
|
||||
for software in self.port_protocol_mapping.values():
|
||||
software_type = "Service" if isinstance(software, Service) else "Application"
|
||||
table.add_row(
|
||||
[service.name, service.operating_state.name, service.health_state_actual.name, service.port.value]
|
||||
[
|
||||
software.name,
|
||||
software_type,
|
||||
software.operating_state.name,
|
||||
software.health_state_actual.name,
|
||||
software.port.value,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -8,8 +9,10 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.session_manager import Session
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
@@ -19,11 +22,11 @@ class DatabaseService(Service):
|
||||
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
|
||||
"""
|
||||
|
||||
backup_server: Optional[IPv4Address] = None
|
||||
"The IP Address of the server the "
|
||||
password: Optional[str] = None
|
||||
connections: Dict[str, datetime] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "Database"
|
||||
kwargs["name"] = "DatabaseService"
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
@@ -62,6 +65,24 @@ class DatabaseService(Service):
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
self.connections[session_id] = datetime.now()
|
||||
self.sys_log.info(f"Connect request for {session_id=} authorised")
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.info(f"Connect request for {session_id=} declined")
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
|
||||
def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
@@ -71,12 +92,21 @@ class DatabaseService(Service):
|
||||
"""
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
return {"status_code": 404, "data": []}
|
||||
|
||||
return {"status_code": 200, "data": self._cursor.fetchall()}
|
||||
data = []
|
||||
description = self._cursor.description
|
||||
if description:
|
||||
headers = []
|
||||
for header in description:
|
||||
headers.append(header[0])
|
||||
data = self._cursor.fetchall()
|
||||
if data and headers:
|
||||
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
|
||||
return {"status_code": 200, "type": "sql", "data": data}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -97,10 +127,20 @@ class DatabaseService(Service):
|
||||
:param session_id: The session identifier.
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
result = self._process_sql(payload)
|
||||
result = {"status_code": 500, "data": []}
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
result = self._process_connect(session_id=session_id, password=payload.get("password"))
|
||||
elif payload["type"] == "disconnect":
|
||||
if session_id in self.connections:
|
||||
self.connections.pop(session_id)
|
||||
elif payload["type"] == "sql":
|
||||
if session_id in self.connections:
|
||||
result = self._process_sql(payload.get("sql"))
|
||||
else:
|
||||
result = {"status_code": 401, "type": "sql"}
|
||||
self.send(payload=result, session_id=session_id)
|
||||
|
||||
return payload["status_code"] == 200
|
||||
return True
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
|
||||
@@ -98,35 +98,30 @@ class Service(IOSoftware):
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
_LOGGER.debug(f"Stopping service {self.name}")
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
"""Start the service."""
|
||||
_LOGGER.debug(f"Starting service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the service."""
|
||||
_LOGGER.debug(f"Pausing service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume paused service."""
|
||||
_LOGGER.debug(f"Resuming service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart running service."""
|
||||
_LOGGER.debug(f"Restarting service {self.name}")
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
@@ -134,13 +129,11 @@ class Service(IOSoftware):
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the service."""
|
||||
_LOGGER.debug(f"Disabling service {self.name}")
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
|
||||
def enable(self) -> None:
|
||||
"""Enable the disabled service."""
|
||||
_LOGGER.debug(f"Enabling service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
@@ -12,6 +12,8 @@ import pytest
|
||||
from primaite import getLogger
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
@@ -24,6 +26,11 @@ from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def uc2_network() -> Network:
|
||||
return arcd_uc2_network()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def file_system() -> FileSystem:
|
||||
return Node(hostname="fs_node").file_system
|
||||
|
||||
@@ -1,39 +1,38 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database import DatabaseService
|
||||
|
||||
|
||||
def test_database_query_across_the_network():
|
||||
def test_database_client_server_connection(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
assert db_client.connect(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
db_client.disconnect()
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
def test_database_client_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
network = arcd_uc2_network()
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
db_client.connect(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
|
||||
client_1.arp.send_arp_request(IPv4Address("192.168.1.14"))
|
||||
db_client.query("SELECT * FROM user;")
|
||||
|
||||
dst_mac_address = client_1.arp.get_arp_cache_mac_address(IPv4Address("192.168.1.14"))
|
||||
web_server_nic = web_server.ethernet_port[1]
|
||||
|
||||
outbound_nic = client_1.arp.get_arp_cache_nic(IPv4Address("192.168.1.14"))
|
||||
client_1.ping("192.168.1.14")
|
||||
web_server_last_payload = web_server_nic.pcap.read()[-1]["payload"]
|
||||
|
||||
frame = Frame(
|
||||
ethernet=EthernetHeader(src_mac_addr=client_1.ethernet_port[1].mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
src_ip_address=client_1.ethernet_port[1].ip_address,
|
||||
dst_ip_address=IPv4Address("192.168.1.14"),
|
||||
precedence=Precedence.FLASH,
|
||||
),
|
||||
tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER),
|
||||
payload="SELECT * FROM user;",
|
||||
)
|
||||
|
||||
outbound_nic.send_frame(frame)
|
||||
|
||||
client_1_last_payload = outbound_nic.pcap.read()[-1]["payload"]
|
||||
|
||||
assert client_1_last_payload["status_code"] == 200
|
||||
assert client_1_last_payload["data"]
|
||||
assert web_server_last_payload["status_code"] == 200
|
||||
assert web_server_last_payload["data"]
|
||||
|
||||
Reference in New Issue
Block a user