From 3dc8a0f222636b8f0ee6c9e2703e7077a9765643 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 29 Sep 2023 20:14:42 +0100 Subject: [PATCH] #1796 - Made the FTP copy real files. Hardcoded the DatabaseService backup folder and filename. Added db restore and final query check to the data manipulation e2e test. --- .../services/database/database_service.py | 39 +++++++++---------- .../system/services/ftp/ftp_client.py | 1 + .../system/services/ftp/ftp_service.py | 10 ++++- .../test_uc2_data_manipulation_scenario.py | 8 ++++ 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 268bd54f..f874b89b 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -41,6 +41,9 @@ class DatabaseService(Service): super().__init__(**kwargs) self._db_file: File self._create_db_file() + self._connect() + + def _connect(self): self._conn = sqlite3.connect(self._db_file.sim_path) self._cursor = self._conn.cursor() @@ -51,8 +54,10 @@ class DatabaseService(Service): :return: List of table names. """ sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';" - results = self._process_sql(sql) - return [row[0] for row in results["data"]] + results = self._process_sql(sql, None) + if isinstance(results["data"], dict): + return list(results["data"].keys()) + return [] def show(self, markdown: bool = False): """ @@ -77,9 +82,7 @@ class DatabaseService(Service): """ self.backup_server = backup_server - def backup_database( - self, backup_directory: Optional[str] = "db_backup", backup_file_name: Optional[str] = None - ) -> bool: + def backup_database(self) -> bool: """ Create a backup of the database to the configured backup server. @@ -94,8 +97,7 @@ class DatabaseService(Service): self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.") return False - if backup_file_name is None: - backup_file_name = f"{datetime.now().strftime('%d-%m-%Y_%H-%M')}.db" + self._conn.close() software_manager: SoftwareManager = self.software_manager ftp_client_service: FTPClient = software_manager.software["FTPClient"] @@ -105,19 +107,19 @@ class DatabaseService(Service): dest_ip_address=self.backup_server, src_file_name=self._db_file.name, src_folder_name=self._db_file.folder.name, - dest_folder_name=backup_directory, - dest_file_name=backup_file_name, + dest_folder_name=str(self.uuid), + dest_file_name="database.db", + real_file_path=self._db_file.sim_path, ) + self._connect() if response: - self.latest_backup_directory = backup_directory - self.latest_backup_file_name = backup_file_name return True self.sys_log.error("Unable to create database backup.") return False - def restore_backup(self, backup_directory: Optional[str] = None, backup_file_name: Optional[str] = None) -> bool: + def restore_backup(self) -> bool: """ Restore a backup from backup server. @@ -127,32 +129,27 @@ class DatabaseService(Service): :param: backup_file_name: Name of file where backup will be stored. Optional. :type: backup_file_name: Optional[str] """ - if backup_directory is None: - backup_directory = self.latest_backup_directory - - if backup_file_name is None: - backup_file_name = self.latest_backup_file_name - software_manager: SoftwareManager = self.software_manager ftp_client_service: FTPClient = software_manager.software["FTPClient"] # retrieve backup file from backup server response = ftp_client_service.request_file( - src_folder_name=backup_directory, - src_file_name=backup_file_name, + src_folder_name=str(self.uuid), + src_file_name="database.db", dest_folder_name="downloads", dest_file_name="database.db", dest_ip_address=self.backup_server, ) if response: + self._conn.close() # replace db file self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db") - self.file_system.move_file( src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name ) self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db") + self._connect() return self._db_file is not None diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 8359e8a0..c22f704b 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -119,6 +119,7 @@ class FTPClient(FTPServiceABC): dest_folder_name: str, dest_file_name: str, dest_port: Optional[Port] = Port.FTP, + real_file_path: Optional[str] = None, ) -> bool: """ Send a file to a target IP address. diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index c6d63751..5314b6a3 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -1,3 +1,4 @@ +import shutil from abc import ABC from ipaddress import IPv4Address from typing import Optional @@ -51,11 +52,17 @@ class FTPServiceABC(Service, ABC): file_name = payload.ftp_command_args["dest_file_name"] folder_name = payload.ftp_command_args["dest_folder_name"] file_size = payload.ftp_command_args["file_size"] - self.file_system.create_file(file_name=file_name, folder_name=folder_name, size=file_size, real=True) + real_file_path = payload.ftp_command_args.get("real_file_path") + is_real = real_file_path is not None + file = self.file_system.create_file( + file_name=file_name, folder_name=folder_name, size=file_size, real=is_real + ) self.sys_log.info( f"Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/" f"{payload.ftp_command_args['dest_file_name']}" ) + if is_real: + shutil.copy(real_file_path, file.sim_path) # file should exist return self.file_system.get_file(file_name=file_name, folder_name=folder_name) is not None except Exception as e: @@ -99,6 +106,7 @@ class FTPServiceABC(Service, ABC): "dest_folder_name": dest_folder_name, "dest_file_name": dest_file_name, "file_size": file.sim_size, + "real_file_path": file.sim_path if file.real else None, }, packet_payload_size=file.sim_size, ) diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 50998f09..955fa20e 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -15,6 +15,8 @@ def test_data_manipulation(uc2_network): web_server: Server = uc2_network.get_node_by_hostname("web_server") db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_service.backup_database() + # First check that the DB client on the web_server can successfully query the users table on the database assert db_client.query("SELECT * FROM user;") @@ -23,3 +25,9 @@ def test_data_manipulation(uc2_network): # Now check that the DB client on the web_server cannot query the users table on the database assert not db_client.query("SELECT * FROM user;") + + # Now restore the database + db_service.restore_backup() + + # Now check that the DB client on the web_server can successfully query the users table on the database + assert db_client.query("SELECT * FROM user;")