Merge branch 'feature/1816_Database-Service-(Network-and-User-Interaction)' into feature/1752-dns-server-and-client
This commit is contained in:
170
src/primaite/simulator/file_system/file_type.py
Normal file
170
src/primaite/simulator/file_system/file_type.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from random import choice
|
||||
|
||||
|
||||
class FileType(Enum):
|
||||
"""An enumeration of common file types."""
|
||||
|
||||
UNKNOWN = 0
|
||||
"Unknown file type."
|
||||
|
||||
# Text formats
|
||||
TXT = 1
|
||||
"Plain text file."
|
||||
DOC = 2
|
||||
"Microsoft Word document (.doc)"
|
||||
DOCX = 3
|
||||
"Microsoft Word document (.docx)"
|
||||
PDF = 4
|
||||
"Portable Document Format."
|
||||
HTML = 5
|
||||
"HyperText Markup Language file."
|
||||
XML = 6
|
||||
"Extensible Markup Language file."
|
||||
CSV = 7
|
||||
"Comma-Separated Values file."
|
||||
|
||||
# Spreadsheet formats
|
||||
XLS = 8
|
||||
"Microsoft Excel file (.xls)"
|
||||
XLSX = 9
|
||||
"Microsoft Excel file (.xlsx)"
|
||||
|
||||
# Image formats
|
||||
JPEG = 10
|
||||
"JPEG image file."
|
||||
PNG = 11
|
||||
"PNG image file."
|
||||
GIF = 12
|
||||
"GIF image file."
|
||||
BMP = 13
|
||||
"Bitmap image file."
|
||||
|
||||
# Audio formats
|
||||
MP3 = 14
|
||||
"MP3 audio file."
|
||||
WAV = 15
|
||||
"WAV audio file."
|
||||
|
||||
# Video formats
|
||||
MP4 = 16
|
||||
"MP4 video file."
|
||||
AVI = 17
|
||||
"AVI video file."
|
||||
MKV = 18
|
||||
"MKV video file."
|
||||
FLV = 19
|
||||
"FLV video file."
|
||||
|
||||
# Presentation formats
|
||||
PPT = 20
|
||||
"Microsoft PowerPoint file (.ppt)"
|
||||
PPTX = 21
|
||||
"Microsoft PowerPoint file (.pptx)"
|
||||
|
||||
# Web formats
|
||||
JS = 22
|
||||
"JavaScript file."
|
||||
CSS = 23
|
||||
"Cascading Style Sheets file."
|
||||
|
||||
# Programming languages
|
||||
PY = 24
|
||||
"Python script file."
|
||||
C = 25
|
||||
"C source code file."
|
||||
CPP = 26
|
||||
"C++ source code file."
|
||||
JAVA = 27
|
||||
"Java source code file."
|
||||
|
||||
# Compressed file types
|
||||
RAR = 28
|
||||
"RAR archive file."
|
||||
ZIP = 29
|
||||
"ZIP archive file."
|
||||
TAR = 30
|
||||
"TAR archive file."
|
||||
GZ = 31
|
||||
"Gzip compressed file."
|
||||
|
||||
# Database file types
|
||||
DB = 32
|
||||
"Generic DB file. Used by sqlite3."
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
return cls.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def random(cls) -> FileType:
|
||||
"""
|
||||
Returns a random FileType.
|
||||
|
||||
:return: A random FileType.
|
||||
"""
|
||||
return choice(list(FileType))
|
||||
|
||||
@property
|
||||
def default_size(self) -> int:
|
||||
"""
|
||||
Get the default size of the FileType in bytes.
|
||||
|
||||
Returns 0 if a default size does not exist.
|
||||
"""
|
||||
size = file_type_sizes_bytes[self]
|
||||
return size if size else 0
|
||||
|
||||
|
||||
def get_file_type_from_extension(file_type_extension: str):
|
||||
"""
|
||||
Get a FileType from a file type extension.
|
||||
|
||||
If a matching extension does not exist, FileType.UNKNOWN is returned.
|
||||
|
||||
:param file_type_extension: A file type extension.
|
||||
:return: A file type extension.
|
||||
"""
|
||||
try:
|
||||
return FileType[file_type_extension.upper()]
|
||||
except KeyError:
|
||||
return FileType.UNKNOWN
|
||||
|
||||
|
||||
file_type_sizes_bytes = {
|
||||
FileType.UNKNOWN: 0,
|
||||
FileType.TXT: 4096,
|
||||
FileType.DOC: 51200,
|
||||
FileType.DOCX: 30720,
|
||||
FileType.PDF: 102400,
|
||||
FileType.HTML: 15360,
|
||||
FileType.XML: 10240,
|
||||
FileType.CSV: 15360,
|
||||
FileType.XLS: 102400,
|
||||
FileType.XLSX: 25600,
|
||||
FileType.JPEG: 102400,
|
||||
FileType.PNG: 40960,
|
||||
FileType.GIF: 30720,
|
||||
FileType.BMP: 307200,
|
||||
FileType.MP3: 5120000,
|
||||
FileType.WAV: 25600000,
|
||||
FileType.MP4: 25600000,
|
||||
FileType.AVI: 51200000,
|
||||
FileType.MKV: 51200000,
|
||||
FileType.FLV: 15360000,
|
||||
FileType.PPT: 204800,
|
||||
FileType.PPTX: 102400,
|
||||
FileType.JS: 10240,
|
||||
FileType.CSS: 5120,
|
||||
FileType.PY: 5120,
|
||||
FileType.C: 5120,
|
||||
FileType.CPP: 10240,
|
||||
FileType.JAVA: 10240,
|
||||
FileType.RAR: 1024000,
|
||||
FileType.ZIP: 1024000,
|
||||
FileType.TAR: 1024000,
|
||||
FileType.GZ: 819200,
|
||||
FileType.DB: 15360000,
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import secrets
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -959,7 +959,24 @@ class Node(SimComponent):
|
||||
)
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
def show(self, markdown: bool = False, component: Literal["NIC", "OPEN_PORTS"] = "NIC"):
|
||||
if component == "NIC":
|
||||
self._show_nic(markdown)
|
||||
elif component == "OPEN_PORTS":
|
||||
self._show_open_ports(markdown)
|
||||
|
||||
def _show_open_ports(self, markdown: bool = False):
|
||||
"""Prints a table of the open ports on the Node."""
|
||||
table = PrettyTable(["Port", "Name"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.hostname} Open Ports"
|
||||
for port in self.software_manager.get_open_ports():
|
||||
table.add_row([port.value, port.name])
|
||||
print(table)
|
||||
|
||||
def _show_nic(self, markdown: bool = False):
|
||||
"""Prints a table of the NICs on the Node."""
|
||||
table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"])
|
||||
if markdown:
|
||||
@@ -1048,29 +1065,30 @@ class Node(SimComponent):
|
||||
:param pings: The number of pings to attempt, default is 4.
|
||||
:return: True if the ping is successful, otherwise False.
|
||||
"""
|
||||
if not isinstance(target_ip_address, IPv4Address):
|
||||
target_ip_address = IPv4Address(target_ip_address)
|
||||
if target_ip_address.is_loopback:
|
||||
self.sys_log.info("Pinging loopback address")
|
||||
return any(nic.enabled for nic in self.nics.values())
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
self.sys_log.info(f"Pinging {target_ip_address}:")
|
||||
sequence, identifier = 0, None
|
||||
while sequence < pings:
|
||||
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
|
||||
request_replies = self.icmp.request_replies.get(identifier)
|
||||
passed = request_replies == pings
|
||||
if request_replies:
|
||||
self.icmp.request_replies.pop(identifier)
|
||||
else:
|
||||
request_replies = 0
|
||||
self.sys_log.info(
|
||||
f"Ping statistics for {target_ip_address}: "
|
||||
f"Packets: Sent = {pings}, "
|
||||
f"Received = {request_replies}, "
|
||||
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
|
||||
)
|
||||
return passed
|
||||
if not isinstance(target_ip_address, IPv4Address):
|
||||
target_ip_address = IPv4Address(target_ip_address)
|
||||
if target_ip_address.is_loopback:
|
||||
self.sys_log.info("Pinging loopback address")
|
||||
return any(nic.enabled for nic in self.nics.values())
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
self.sys_log.info(f"Pinging {target_ip_address}:")
|
||||
sequence, identifier = 0, None
|
||||
while sequence < pings:
|
||||
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
|
||||
request_replies = self.icmp.request_replies.get(identifier)
|
||||
passed = request_replies == pings
|
||||
if request_replies:
|
||||
self.icmp.request_replies.pop(identifier)
|
||||
else:
|
||||
request_replies = 0
|
||||
self.sys_log.info(
|
||||
f"Ping statistics for {target_ip_address}: "
|
||||
f"Packets: Sent = {pings}, "
|
||||
f"Received = {request_replies}, "
|
||||
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
|
||||
)
|
||||
return passed
|
||||
return False
|
||||
|
||||
def send_frame(self, frame: Frame):
|
||||
@@ -1079,7 +1097,8 @@ class Node(SimComponent):
|
||||
|
||||
:param frame: The Frame to be sent.
|
||||
"""
|
||||
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
|
||||
nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame, from_nic: NIC):
|
||||
@@ -1092,20 +1111,27 @@ class Node(SimComponent):
|
||||
:param frame: The Frame being received.
|
||||
:param from_nic: The NIC that received the frame.
|
||||
"""
|
||||
if frame.ip:
|
||||
if frame.ip.src_ip_address in self.arp:
|
||||
self.arp.add_arp_cache_entry(
|
||||
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
|
||||
)
|
||||
if frame.ip.protocol == IPProtocol.TCP:
|
||||
if frame.tcp.src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
if frame.ip:
|
||||
if frame.ip.src_ip_address in self.arp:
|
||||
self.arp.add_arp_cache_entry(
|
||||
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
|
||||
)
|
||||
if frame.ip.protocol == IPProtocol.ICMP:
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
return
|
||||
# Check if the destination port is open on the Node
|
||||
if frame.tcp.dst_port in self.software_manager.get_open_ports():
|
||||
# accept thr frame as the port is open
|
||||
if frame.tcp.src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
|
||||
else:
|
||||
self.session_manager.receive_frame(frame)
|
||||
else:
|
||||
self.session_manager.receive_frame(frame)
|
||||
elif frame.ip.protocol == IPProtocol.UDP:
|
||||
pass
|
||||
elif frame.ip.protocol == IPProtocol.ICMP:
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
# denied as port closed
|
||||
self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}")
|
||||
# TODO: do we need to do anything more here?
|
||||
pass
|
||||
|
||||
def install_service(self, service: Service) -> None:
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database import DatabaseService
|
||||
|
||||
|
||||
@@ -149,6 +150,9 @@ def arcd_uc2_network() -> Network:
|
||||
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
)
|
||||
web_server.power_on()
|
||||
web_server.software_manager.install(DatabaseClient)
|
||||
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
database_client.run()
|
||||
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
|
||||
|
||||
# Database Server
|
||||
@@ -187,12 +191,12 @@ def arcd_uc2_network() -> Network:
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa
|
||||
]
|
||||
database_server.software_manager.add_service(DatabaseService)
|
||||
database: DatabaseService = database_server.software_manager.services["Database"] # noqa
|
||||
database.start()
|
||||
database._process_sql(ddl) # noqa
|
||||
database_server.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
|
||||
database_service.start()
|
||||
database_service._process_sql(ddl) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database._process_sql(insert_statement) # noqa
|
||||
database_service._process_sql(insert_statement) # noqa
|
||||
|
||||
# Backup Server
|
||||
backup_server = Server(
|
||||
|
||||
@@ -23,9 +23,9 @@ class Application(IOSoftware):
|
||||
Applications are user-facing programs that may perform input/output operations.
|
||||
"""
|
||||
|
||||
operating_state: ApplicationOperatingState
|
||||
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
|
||||
"The current operating state of the Application."
|
||||
execution_control_status: str
|
||||
execution_control_status: str = "manual"
|
||||
"Control status of the application's execution. It could be 'manual' or 'automatic'."
|
||||
num_executions: int = 0
|
||||
"The number of times the application has been executed. Default is 0."
|
||||
@@ -53,6 +53,25 @@ class Application(IOSoftware):
|
||||
)
|
||||
return state
|
||||
|
||||
def run(self) -> None:
|
||||
"""Open the Application"""
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Running Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.RUNNING
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the Application"""
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Closed Application{self.name}")
|
||||
self.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
def install(self) -> None:
|
||||
"""Install Application."""
|
||||
super().install()
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Installing Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.INSTALLING
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Application component for a new episode.
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
connected: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return super().describe_state()
|
||||
|
||||
def connect(self, server_ip_address: IPv4Address, password: Optional[str] = None) -> bool:
|
||||
if not self.connected and self.operating_state.RUNNING:
|
||||
return self._connect(server_ip_address, password)
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
) -> bool:
|
||||
if is_reattempt:
|
||||
if self.connected:
|
||||
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised")
|
||||
self.server_ip_address = server_ip_address
|
||||
return self.connected
|
||||
else:
|
||||
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined")
|
||||
payload = {"type": "connect_request", "password": password}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(server_ip_address, password, True)
|
||||
|
||||
def disconnect(self):
|
||||
if self.connected and self.operating_state.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
|
||||
self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}")
|
||||
self.server_ip_address = None
|
||||
|
||||
def query(self, sql: str):
|
||||
if self.connected and self.operating_state.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "sql", "sql": sql}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
|
||||
def _print_data(self, data: Dict):
|
||||
"""
|
||||
Display the contents of the Folder in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(list(data.values())[0])
|
||||
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Database Client"
|
||||
for row in data.values():
|
||||
table.add_row(row.values())
|
||||
print(table)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
elif payload["type"] == "sql":
|
||||
self._print_data(payload["data"])
|
||||
return True
|
||||
@@ -1,15 +1,15 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareType
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from typing import Type, TypeVar
|
||||
|
||||
ServiceClass = TypeVar("ServiceClass", bound=Service)
|
||||
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
|
||||
|
||||
class SoftwareManager:
|
||||
@@ -30,57 +30,55 @@ class SoftwareManager:
|
||||
:param session_manager: The session manager handling network communications.
|
||||
"""
|
||||
self.session_manager = session_manager
|
||||
self.services: Dict[str, Service] = {}
|
||||
self.applications: Dict[str, Application] = {}
|
||||
self.software: Dict[str, Union[Service, Application]] = {}
|
||||
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
|
||||
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
|
||||
self.sys_log: SysLog = sys_log
|
||||
self.file_system: FileSystem = file_system
|
||||
|
||||
def add_service(self, service_class: Type[ServiceClass]):
|
||||
"""
|
||||
Add a Service to the manager.
|
||||
def get_open_ports(self) -> List[Port]:
|
||||
open_ports = [Port.ARP]
|
||||
for software in self.port_protocol_mapping.values():
|
||||
if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}:
|
||||
open_ports.append(software.port)
|
||||
open_ports.sort(key=lambda port: port.value)
|
||||
return open_ports
|
||||
|
||||
:param: service_class: The class of the service to add
|
||||
"""
|
||||
service = service_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
|
||||
def install(self, software_class: Type[IOSoftwareClass]):
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
|
||||
if isinstance(software, Application):
|
||||
software.install()
|
||||
software.software_manager = self
|
||||
self.software[software.name] = software
|
||||
self.port_protocol_mapping[(software.port, software.protocol)] = software
|
||||
self.sys_log.info(f"Installed {software.name}")
|
||||
if isinstance(software, Application):
|
||||
software.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
service.software_manager = self
|
||||
self.services[service.name] = service
|
||||
self.port_protocol_mapping[(service.port, service.protocol)] = service
|
||||
def uninstall(self, software_name: str):
|
||||
if software_name in self.software:
|
||||
software = self.software.pop(software_name) # noqa
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol):
|
||||
"""
|
||||
Add an Application to the manager.
|
||||
|
||||
:param name: The name of the application.
|
||||
:param application: The application instance.
|
||||
:param port: The port used by the application.
|
||||
:param protocol: The network protocol used by the application.
|
||||
"""
|
||||
application.software_manager = self
|
||||
self.applications[name] = application
|
||||
self.port_protocol_mapping[(port, protocol)] = application
|
||||
|
||||
def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any):
|
||||
def send_internal_payload(self, target_software: str, payload: Any):
|
||||
"""
|
||||
Send a payload to a specific service or application.
|
||||
|
||||
:param target_software: The name of the target service or application.
|
||||
:param target_software_type: The type of software (Service, Application, Process).
|
||||
:param payload: The data to be sent.
|
||||
:param receiver_type: The type of the target, either 'service' or 'application'.
|
||||
"""
|
||||
if target_software_type is SoftwareType.SERVICE:
|
||||
receiver = self.services.get(target_software)
|
||||
elif target_software_type is SoftwareType.APPLICATION:
|
||||
receiver = self.applications.get(target_software)
|
||||
else:
|
||||
raise ValueError(f"Invalid receiver type {target_software_type}")
|
||||
receiver = self.software.get(target_software)
|
||||
|
||||
if receiver:
|
||||
receiver.receive_payload(payload)
|
||||
else:
|
||||
raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}")
|
||||
self.sys_log.error(f"No Service of Application found with the name {target_software}")
|
||||
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
@@ -121,13 +119,20 @@ class SoftwareManager:
|
||||
|
||||
:param markdown: If True, outputs the table in markdown format. Default is False.
|
||||
"""
|
||||
table = PrettyTable(["Name", "Operating State", "Health State", "Port"])
|
||||
table = PrettyTable(["Name", "Type", "Operating State", "Health State", "Port"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Software Manager"
|
||||
for service in self.services.values():
|
||||
for software in self.port_protocol_mapping.values():
|
||||
software_type = "Service" if isinstance(software, Service) else "Application"
|
||||
table.add_row(
|
||||
[service.name, service.operating_state.name, service.health_state_actual.name, service.port.value]
|
||||
[
|
||||
software.name,
|
||||
software_type,
|
||||
software.operating_state.name,
|
||||
software.health_state_actual.name,
|
||||
software.port.value,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -8,8 +9,10 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.session_manager import Session
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
@@ -19,11 +22,11 @@ class DatabaseService(Service):
|
||||
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
|
||||
"""
|
||||
|
||||
backup_server: Optional[IPv4Address] = None
|
||||
"The IP Address of the server the "
|
||||
password: Optional[str] = None
|
||||
connections: Dict[str, datetime] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "Database"
|
||||
kwargs["name"] = "DatabaseService"
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
@@ -62,6 +65,24 @@ class DatabaseService(Service):
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
self.connections[session_id] = datetime.now()
|
||||
self.sys_log.info(f"Connect request for {session_id=} authorised")
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.info(f"Connect request for {session_id=} declined")
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
|
||||
def _process_sql(self, query: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
@@ -71,12 +92,21 @@ class DatabaseService(Service):
|
||||
"""
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
return {"status_code": 404, "data": []}
|
||||
|
||||
return {"status_code": 200, "data": self._cursor.fetchall()}
|
||||
data = []
|
||||
description = self._cursor.description
|
||||
if description:
|
||||
headers = []
|
||||
for header in description:
|
||||
headers.append(header[0])
|
||||
data = self._cursor.fetchall()
|
||||
if data and headers:
|
||||
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
|
||||
return {"status_code": 200, "type": "sql", "data": data}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -97,10 +127,20 @@ class DatabaseService(Service):
|
||||
:param session_id: The session identifier.
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
result = self._process_sql(payload)
|
||||
result = {"status_code": 500, "data": []}
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
result = self._process_connect(session_id=session_id, password=payload.get("password"))
|
||||
elif payload["type"] == "disconnect":
|
||||
if session_id in self.connections:
|
||||
self.connections.pop(session_id)
|
||||
elif payload["type"] == "sql":
|
||||
if session_id in self.connections:
|
||||
result = self._process_sql(payload.get("sql"))
|
||||
else:
|
||||
result = {"status_code": 401, "type": "sql"}
|
||||
self.send(payload=result, session_id=session_id)
|
||||
|
||||
return payload["status_code"] == 200
|
||||
return True
|
||||
|
||||
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
|
||||
@@ -125,35 +125,30 @@ class Service(IOSoftware):
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
_LOGGER.debug(f"Stopping service {self.name}")
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
"""Start the service."""
|
||||
_LOGGER.debug(f"Starting service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the service."""
|
||||
_LOGGER.debug(f"Pausing service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume paused service."""
|
||||
_LOGGER.debug(f"Resuming service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart running service."""
|
||||
_LOGGER.debug(f"Restarting service {self.name}")
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
@@ -161,13 +156,11 @@ class Service(IOSoftware):
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the service."""
|
||||
_LOGGER.debug(f"Disabling service {self.name}")
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
|
||||
def enable(self) -> None:
|
||||
"""Enable the disabled service."""
|
||||
_LOGGER.debug(f"Enabling service {self.name}")
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
Reference in New Issue
Block a user