diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst index c9f8977a..489f8ae5 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -41,7 +41,7 @@ Example network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] - data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;") + data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table. diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index ee42cf4f..ddf9d923 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -81,7 +81,7 @@ game_config: # options: # execution_definition: # server_ip: 192.168.1.14 - # payload: "DROP TABLE IF EXISTS user;" + # payload: "DELETE" # success_rate: 80% - type: NODE_FILE_DELETE - type: NODE_FILE_CORRUPT diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 25d1bd21..c0f9a07e 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network: network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] - db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;") + db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") # Client 2 client_2 = Computer( diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index d021cb78..37f89371 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -2,13 +2,14 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional from uuid import uuid4 -from prettytable import PrettyTable - +from primaite import getLogger from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.software_manager import SoftwareManager +_LOGGER = getLogger(__name__) + class DatabaseClient(Application): """ @@ -148,21 +149,6 @@ class DatabaseClient(Application): self._query_success_tracker[query_id] = False return self._query(sql=sql, query_id=query_id) - def _print_data(self, data: Dict): - """ - Display the contents of the Folder in tabular format. - - :param markdown: Whether to display the table in Markdown format or not. Default is `False`. - """ - if data: - table = PrettyTable(list(data.values())[0]) - - table.align = "l" - table.title = f"{self.sys_log.hostname} Database Client" - for row in data.values(): - table.add_row(row.values()) - print(table) - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ Receive a payload from the Software Manager. @@ -179,5 +165,5 @@ class DatabaseClient(Application): status_code = payload.get("status_code") self._query_success_tracker[query_id] = status_code == 200 if self._query_success_tracker[query_id]: - self._print_data(payload["data"]) + _LOGGER.debug(f"Received payload {payload}") return True diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index b04174bf..d7277e1e 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -1,10 +1,6 @@ -import sqlite3 from datetime import datetime from ipaddress import IPv4Address -from sqlite3 import OperationalError -from typing import Any, Dict, List, Optional, Union - -from prettytable import MARKDOWN, PrettyTable +from typing import Any, Dict, List, Literal, Optional, Union from primaite.simulator.file_system.file_system import File from primaite.simulator.network.transmission.network_layer import IPProtocol @@ -19,7 +15,7 @@ class DatabaseService(Service): """ A class for simulating a generic SQL Server service. - This class inherits from the `Service` class and provides methods to manage and query a SQLite database. + This class inherits from the `Service` class and provides methods to simulate a SQL database. """ password: Optional[str] = None @@ -41,38 +37,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._db_file: File self._create_db_file() - self._connect() - - def _connect(self): - self._conn = sqlite3.connect(self._db_file.sim_path) - self._cursor = self._conn.cursor() - - def tables(self) -> List[str]: - """ - Get a list of table names present in the database. - - :return: List of table names. - """ - sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';" - results = self._process_sql(sql, None) - if isinstance(results["data"], dict): - return list(results["data"].keys()) - return [] - - def show(self, markdown: bool = False): - """ - Prints a list of table names in the database using PrettyTable. - - :param markdown: Whether to output the table in Markdown format. - """ - table = PrettyTable(["Table"]) - if markdown: - table.set_style(MARKDOWN) - table.align = "l" - table.title = f"{self.file_system.sys_log.hostname} Database" - for row in self.tables(): - table.add_row([row]) - print(table) def configure_backup(self, backup_server: IPv4Address): """ @@ -89,8 +53,6 @@ class DatabaseService(Service): self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.") return False - self._conn.close() - software_manager: SoftwareManager = self.software_manager ftp_client_service: FTPClient = software_manager.software["FTPClient"] @@ -98,12 +60,10 @@ class DatabaseService(Service): response = ftp_client_service.send_file( dest_ip_address=self.backup_server, src_file_name=self._db_file.name, - src_folder_name=self._db_file.folder.name, + src_folder_name=self.folder.name, dest_folder_name=str(self.uuid), dest_file_name="database.db", - real_file_path=self._db_file.sim_path, ) - self._connect() if response: return True @@ -125,25 +85,29 @@ class DatabaseService(Service): dest_ip_address=self.backup_server, ) - if response: - self._conn.close() - # replace db file - self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db") - self.file_system.copy_file( - src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name - ) - self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db") - self._connect() + if not response: + self.sys_log.error("Unable to restore database backup.") + return False - return self._db_file is not None + # replace db file + self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db") + self.file_system.copy_file( + src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name + ) + self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db") - self.sys_log.error("Unable to restore database backup.") - return False + if self._db_file is None: + self.sys_log.error("Copying database backup failed.") + return False + + self.set_health_state(SoftwareHealthState.GOOD) + + return True def _create_db_file(self): """Creates the Simulation File and sqlite file in the file system.""" - self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True) - self.folder = self._db_file.folder + self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db") + self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id) def _process_connect( self, session_id: str, password: Optional[str] = None @@ -163,31 +127,32 @@ class DatabaseService(Service): status_code = 404 # service not found return {"status_code": status_code, "type": "connect_response", "response": status_code == 200} - def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]: + def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]: """ Executes the given SQL query and returns the result. + Possible queries: + - SELECT : returns the data + - DELETE : deletes the data + :param query: The SQL query to be executed. :return: Dictionary containing status code and data fetched. """ self.sys_log.info(f"{self.name}: Running {query}") - try: - self._cursor.execute(query) - self._conn.commit() - except OperationalError: - # Handle the case where the table does not exist. - self.sys_log.error(f"{self.name}: Error, query failed") - return {"status_code": 404, "data": {}} - data = [] - description = self._cursor.description - if description: - headers = [] - for header in description: - headers.append(header[0]) - data = self._cursor.fetchall() - if data and headers: - data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data} - return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id} + if query == "SELECT": + if self.health_state_actual == SoftwareHealthState.GOOD: + return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id} + else: + return {"status_code": 404, "data": False} + elif query == "DELETE": + if self.health_state_actual == SoftwareHealthState.GOOD: + self.health_state_actual = SoftwareHealthState.COMPROMISED + return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id} + else: + return {"status_code": 404, "data": False} + else: + # Invalid query + return {"status_code": 500, "data": False} def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5957e4cb..cb1a4738 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -106,7 +106,7 @@ class WebServer(Service): # get data from DatabaseServer db_client: DatabaseClient = self.software_manager.software["DatabaseClient"] # get all users - if db_client.query("SELECT * FROM user;"): + if db_client.query("SELECT"): # query succeeded response.status_code = HttpStatusCode.OK diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 13f4d1f3..81bbfc96 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -19,16 +19,16 @@ def test_data_manipulation(uc2_network): db_service.backup_database() # First check that the DB client on the web_server can successfully query the users table on the database - assert db_client.query("SELECT * FROM user;") + assert db_client.query("SELECT") # Now we run the DataManipulationBot db_manipulation_bot.run() # Now check that the DB client on the web_server cannot query the users table on the database - assert not db_client.query("SELECT * FROM user;") + assert not db_client.query("SELECT") # Now restore the database db_service.restore_backup() # Now check that the DB client on the web_server can successfully query the users table on the database - assert db_client.query("SELECT * FROM user;") + assert db_client.query("SELECT") diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 92056981..027fae4a 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -57,7 +57,7 @@ def test_database_client_query(uc2_network): db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client.connect() - assert db_client.query("SELECT * FROM user;") + assert db_client.query("SELECT") def test_create_database_backup(uc2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index dd785cc1..113ebeb4 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -17,4 +17,4 @@ def test_creation(): assert data_manipulation_bot.name == "DataManipulationBot" assert data_manipulation_bot.port == Port.POSTGRES_SERVER assert data_manipulation_bot.protocol == IPProtocol.TCP - assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;" + assert data_manipulation_bot.payload == "DELETE"