Merge branch 'dev' into feature/2041_2042-Add-NTP-Services
This commit is contained in:
@@ -1,10 +1,6 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -19,7 +15,7 @@ class DatabaseService(Service):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
@@ -41,38 +37,6 @@ class DatabaseService(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
self._conn = sqlite3.connect(self._db_file.sim_path)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def tables(self) -> List[str]:
|
||||
"""
|
||||
Get a list of table names present in the database.
|
||||
|
||||
:return: List of table names.
|
||||
"""
|
||||
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
|
||||
results = self._process_sql(sql, None)
|
||||
if isinstance(results["data"], dict):
|
||||
return list(results["data"].keys())
|
||||
return []
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Prints a list of table names in the database using PrettyTable.
|
||||
|
||||
:param markdown: Whether to output the table in Markdown format.
|
||||
"""
|
||||
table = PrettyTable(["Table"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.file_system.sys_log.hostname} Database"
|
||||
for row in self.tables():
|
||||
table.add_row([row])
|
||||
print(table)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
@@ -89,8 +53,6 @@ class DatabaseService(Service):
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
self._conn.close()
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
|
||||
@@ -98,12 +60,10 @@ class DatabaseService(Service):
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self._db_file.folder.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
real_file_path=self._db_file.sim_path,
|
||||
)
|
||||
self._connect()
|
||||
|
||||
if response:
|
||||
return True
|
||||
@@ -125,25 +85,29 @@ class DatabaseService(Service):
|
||||
dest_ip_address=self.backup_server,
|
||||
)
|
||||
|
||||
if response:
|
||||
self._conn.close()
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self._connect()
|
||||
if not response:
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
|
||||
return self._db_file is not None
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
if self._db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
return True
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
@@ -163,31 +127,32 @@ class DatabaseService(Service):
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
|
||||
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
Possible queries:
|
||||
- SELECT : returns the data
|
||||
- DELETE : deletes the data
|
||||
|
||||
:param query: The SQL query to be executed.
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
self.sys_log.error(f"{self.name}: Error, query failed")
|
||||
return {"status_code": 404, "data": {}}
|
||||
data = []
|
||||
description = self._cursor.description
|
||||
if description:
|
||||
headers = []
|
||||
for header in description:
|
||||
headers.append(header[0])
|
||||
data = self._cursor.fetchall()
|
||||
if data and headers:
|
||||
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
|
||||
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,7 @@ class WebServer(Service):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
|
||||
# get all users
|
||||
if db_client.query("SELECT * FROM user;"):
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
response.status_code = HttpStatusCode.OK
|
||||
|
||||
|
||||
Reference in New Issue
Block a user