diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index cbd640f6..22ae0ff3 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -197,7 +197,11 @@ class DatabaseService(Service): status_code = 503 # service unavailable if self.health_state_actual == SoftwareHealthState.OVERWHELMED: self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.") - if self.health_state_actual == SoftwareHealthState.GOOD: + if self.health_state_actual in [ + SoftwareHealthState.GOOD, + SoftwareHealthState.FIXING, + SoftwareHealthState.COMPROMISED, + ]: if self.password == password: status_code = 200 # ok connection_id = self._generate_connection_id() @@ -244,6 +248,10 @@ class DatabaseService(Service): self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.") return {"status_code": 404, "type": "sql", "data": False} + if self.health_state_actual is not SoftwareHealthState.GOOD: + self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.") + return {"status_code": 500, "type": "sql", "data": False} + if query == "SELECT": if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT: return { diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 8da3bb1a..965b4ae8 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -14,6 +14,7 @@ from primaite.simulator.system.applications.database_client import DatabaseClien from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") @@ -213,6 +214,110 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network): assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # now looks good +def test_database_service_fix(uc2_network): + """Test that the software fix applies to database service.""" + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + assert db_service.backup_database() is True + + # delete database locally + db_service.file_system.delete_file(folder_name="database", file_name="database.db") + + # db file is gone, reduced to atoms + assert db_service.db_file is None + + db_service.fix() # fix the database service + + assert db_service.health_state_actual == SoftwareHealthState.FIXING + + # apply timestep until the fix is applied + for i in range(db_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD + assert db_service.health_state_actual == SoftwareHealthState.GOOD + + +def test_database_cannot_be_queried_while_fixing(uc2_network): + """Tests that the database service cannot be queried if the service is being fixed.""" + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + + db_connection: DatabaseClientConnection = db_client.get_new_connection() + + assert db_connection.query(sql="SELECT") + + assert db_service.backup_database() is True + + # delete database locally + db_service.file_system.delete_file(folder_name="database", file_name="database.db") + + # db file is gone, reduced to atoms + assert db_service.db_file is None + + db_service.fix() # fix the database service + assert db_service.health_state_actual == SoftwareHealthState.FIXING + + # fails to query because database is in FIXING state + assert db_connection.query(sql="SELECT") is False + + # apply timestep until the fix is applied + for i in range(db_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert db_service.health_state_actual == SoftwareHealthState.GOOD + + assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD + + assert db_connection.query(sql="SELECT") + + +def test_database_can_create_connection_while_fixing(uc2_network): + """Tests that connections cannot be created while the database is being fixed.""" + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + client_2: Server = uc2_network.get_node_by_hostname("client_2") + db_client: DatabaseClient = client_2.software_manager.software["DatabaseClient"] + + db_connection: DatabaseClientConnection = db_client.get_new_connection() + + assert db_connection.query(sql="SELECT") + + assert db_service.backup_database() is True + + # delete database locally + db_service.file_system.delete_file(folder_name="database", file_name="database.db") + + # db file is gone, reduced to atoms + assert db_service.db_file is None + + db_service.fix() # fix the database service + assert db_service.health_state_actual == SoftwareHealthState.FIXING + + # fails to query because database is in FIXING state + assert db_connection.query(sql="SELECT") is False + + # should be able to create a new connection + new_db_connection: DatabaseClientConnection = db_client.get_new_connection() + assert new_db_connection is not None + assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING + + # apply timestep until the fix is applied + for i in range(db_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert db_service.health_state_actual == SoftwareHealthState.GOOD + assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD + + assert db_connection.query(sql="SELECT") + assert new_db_connection.query(sql="SELECT") + + def test_database_client_cannot_query_offline_database_server(uc2_network): """Tests DB query across the network returns HTTP status 404 when db server is offline.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 3fe77fa0..5a765763 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -16,6 +16,7 @@ from primaite.simulator.system.services.database.database_service import Databas from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") @@ -110,6 +111,29 @@ def test_web_client_requests_users(web_client_web_server_database): assert web_browser.get_webpage() +def test_database_fix_disrupts_web_client(uc2_network): + """Tests that the database service being in fixed state disrupts the web client.""" + computer: Computer = uc2_network.get_node_by_hostname("client_1") + db_server: Server = uc2_network.get_node_by_hostname("database_server") + + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + database_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + + # fix the database service + database_service.fix() + + assert database_service.health_state_actual == SoftwareHealthState.FIXING + + assert web_browser.get_webpage() is False + + for i in range(database_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert database_service.health_state_actual == SoftwareHealthState.GOOD + + assert web_browser.get_webpage() + + class TestWebBrowserHistory: def test_populating_history(self, web_client_web_server_database): network, computer, _, _ = web_client_web_server_database