Merged PR 213: Replace SQLite database implementation with a simulation
## Summary Remove SQLLite and real sql queries and replace them with `SELECT` and `DELETE` queries that interact with the node's file system and service operating status. ## Test process Existing unit tests pass after logic was changed. ## Checklist - [x] PR is linked to a **work item** - [ ] **acceptance criteria** of linked ticket are met - [x] performed **self-review** of the code - [x] written **tests** for any new functionality added with this PR - [x] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [x] ran **pre-commit** checks for code style - [x] attended to any **TO-DOs** left in the code Related work items: #1971
This commit is contained in:
@@ -41,7 +41,7 @@ Example
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
|
||||
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
data_manipulation_bot.run()
|
||||
|
||||
This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table.
|
||||
|
||||
@@ -14,10 +14,10 @@ The ``DatabaseService`` provides a SQL database server simulation by extending t
|
||||
Key capabilities
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
- Initialises a SQLite database file in the ``Node`` 's ``FileSystem`` upon creation.
|
||||
- Creates a database file in the ``Node`` 's ``FileSystem`` upon creation.
|
||||
- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs.
|
||||
- Authenticates connections using a configurable password.
|
||||
- Executes SQL queries against the SQLite database.
|
||||
- Simulates ``SELECT`` and ``DELETE`` SQL queries.
|
||||
- Returns query results and status codes back to clients.
|
||||
- Leverages the Service base class for install/uninstall, status tracking, etc.
|
||||
|
||||
@@ -30,10 +30,9 @@ Usage
|
||||
Implementation
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
- Uses SQLite for persistent storage.
|
||||
- Creates the database file within the node's file system.
|
||||
- Manages client connections in a dictionary by session ID.
|
||||
- Processes SQL queries via the SQLite cursor and connection.
|
||||
- Processes SQL queries.
|
||||
- Returns results and status codes in a standard dictionary format.
|
||||
- Extends Service class for integration with ``SoftwareManager``.
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ game_config:
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# payload: "DELETE"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
|
||||
@@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network:
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
|
||||
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
|
||||
# Client 2
|
||||
client_2 = Computer(
|
||||
|
||||
@@ -2,13 +2,14 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
"""
|
||||
@@ -148,21 +149,6 @@ class DatabaseClient(Application):
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
|
||||
def _print_data(self, data: Dict):
|
||||
"""
|
||||
Display the contents of the Folder in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
if data:
|
||||
table = PrettyTable(list(data.values())[0])
|
||||
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Database Client"
|
||||
for row in data.values():
|
||||
table.add_row(row.values())
|
||||
print(table)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
@@ -179,5 +165,5 @@ class DatabaseClient(Application):
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
self._print_data(payload["data"])
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
return True
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -19,7 +15,7 @@ class DatabaseService(Service):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
@@ -41,38 +37,6 @@ class DatabaseService(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
self._conn = sqlite3.connect(self._db_file.sim_path)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def tables(self) -> List[str]:
|
||||
"""
|
||||
Get a list of table names present in the database.
|
||||
|
||||
:return: List of table names.
|
||||
"""
|
||||
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
|
||||
results = self._process_sql(sql, None)
|
||||
if isinstance(results["data"], dict):
|
||||
return list(results["data"].keys())
|
||||
return []
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Prints a list of table names in the database using PrettyTable.
|
||||
|
||||
:param markdown: Whether to output the table in Markdown format.
|
||||
"""
|
||||
table = PrettyTable(["Table"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.file_system.sys_log.hostname} Database"
|
||||
for row in self.tables():
|
||||
table.add_row([row])
|
||||
print(table)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
@@ -89,8 +53,6 @@ class DatabaseService(Service):
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
self._conn.close()
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
|
||||
@@ -98,12 +60,10 @@ class DatabaseService(Service):
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self._db_file.folder.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
real_file_path=self._db_file.sim_path,
|
||||
)
|
||||
self._connect()
|
||||
|
||||
if response:
|
||||
return True
|
||||
@@ -125,25 +85,29 @@ class DatabaseService(Service):
|
||||
dest_ip_address=self.backup_server,
|
||||
)
|
||||
|
||||
if response:
|
||||
self._conn.close()
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self._connect()
|
||||
if not response:
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
|
||||
return self._db_file is not None
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
if self._db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
return True
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
@@ -163,31 +127,32 @@ class DatabaseService(Service):
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
|
||||
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
Possible queries:
|
||||
- SELECT : returns the data
|
||||
- DELETE : deletes the data
|
||||
|
||||
:param query: The SQL query to be executed.
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
self.sys_log.error(f"{self.name}: Error, query failed")
|
||||
return {"status_code": 404, "data": {}}
|
||||
data = []
|
||||
description = self._cursor.description
|
||||
if description:
|
||||
headers = []
|
||||
for header in description:
|
||||
headers.append(header[0])
|
||||
data = self._cursor.fetchall()
|
||||
if data and headers:
|
||||
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
|
||||
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,7 @@ class WebServer(Service):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
|
||||
# get all users
|
||||
if db_client.query("SELECT * FROM user;"):
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
response.status_code = HttpStatusCode.OK
|
||||
|
||||
|
||||
@@ -19,16 +19,16 @@ def test_data_manipulation(uc2_network):
|
||||
db_service.backup_database()
|
||||
|
||||
# First check that the DB client on the web_server can successfully query the users table on the database
|
||||
assert db_client.query("SELECT * FROM user;")
|
||||
assert db_client.query("SELECT")
|
||||
|
||||
# Now we run the DataManipulationBot
|
||||
db_manipulation_bot.run()
|
||||
|
||||
# Now check that the DB client on the web_server cannot query the users table on the database
|
||||
assert not db_client.query("SELECT * FROM user;")
|
||||
assert not db_client.query("SELECT")
|
||||
|
||||
# Now restore the database
|
||||
db_service.restore_backup()
|
||||
|
||||
# Now check that the DB client on the web_server can successfully query the users table on the database
|
||||
assert db_client.query("SELECT * FROM user;")
|
||||
assert db_client.query("SELECT")
|
||||
|
||||
@@ -57,7 +57,7 @@ def test_database_client_query(uc2_network):
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client.connect()
|
||||
|
||||
assert db_client.query("SELECT * FROM user;")
|
||||
assert db_client.query("SELECT")
|
||||
|
||||
|
||||
def test_create_database_backup(uc2_network):
|
||||
|
||||
@@ -17,4 +17,4 @@ def test_creation():
|
||||
assert data_manipulation_bot.name == "DataManipulationBot"
|
||||
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
|
||||
assert data_manipulation_bot.protocol == IPProtocol.TCP
|
||||
assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;"
|
||||
assert data_manipulation_bot.payload == "DELETE"
|
||||
|
||||
Reference in New Issue
Block a user