# Files
# PrimAITE/src/primaite/simulator/system/services/database/database_service.py
#
# 234 lines
# 9.3 KiB
# Python

from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
# Module-level logger; used below for debug messages when the service state is set or reset.
_LOGGER = getLogger(__name__)
class DatabaseService(Service):
    """
    A class for simulating a generic SQL Server service.

    This class inherits from the `Service` class and provides methods to simulate a SQL database.
    The database itself is represented by a single simulated file ("database.db"); queries only
    inspect or change that file's health status.
    """

    password: Optional[str] = None
    """Password a client must supply in a connect request (``None`` means no password)."""

    connections: Dict[str, datetime] = {}
    """Authorised sessions, keyed by session id, with the time the connection was accepted."""

    backup_server_ip: Optional[IPv4Address] = None
    """IP address of the backup server."""

    latest_backup_directory: Optional[str] = None
    """Directory of latest backup."""

    latest_backup_file_name: Optional[str] = None
    """File name of latest backup."""

    def __init__(self, **kwargs):
        """
        Create the service with its fixed name, port and protocol, then create the database file.

        :param kwargs: Forwarded to ``Service.__init__``; ``name``, ``port`` and ``protocol`` are overwritten.
        """
        kwargs["name"] = "DatabaseService"
        kwargs["port"] = Port.POSTGRES_SERVER
        kwargs["protocol"] = IPProtocol.TCP
        super().__init__(**kwargs)
        self._db_file: File
        self._create_db_file()

    def set_original_state(self):
        """Sets the original state."""
        _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
        super().set_original_state()
        # Names must match the declared fields exactly: unknown names are silently ignored by model_dump.
        vals_to_include = {
            "password",
            "connections",
            "backup_server_ip",
            "latest_backup_directory",
            "latest_backup_file_name",
        }
        self._original_state.update(self.model_dump(include=vals_to_include))

    def reset_component_for_episode(self, episode: int):
        """
        Reset the original state of the SimComponent.

        :param episode: The episode number, forwarded to the parent implementation.
        """
        _LOGGER.debug(f"Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
        # Drop every open session before the parent implementation runs.
        self.connections.clear()
        super().reset_component_for_episode(episode)

    def configure_backup(self, backup_server: IPv4Address):
        """
        Set up the database backup.

        :param backup_server: The IP address of the backup server
        """
        self.backup_server_ip = backup_server

    def backup_database(self) -> bool:
        """
        Create a backup of the database to the configured backup server.

        :return: True if the FTP client reported a successful transfer, otherwise False.
        """
        # check if this action can be performed
        if not self._can_perform_action():
            return False

        # check if the backup server was configured
        if self.backup_server_ip is None:
            self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
            return False

        software_manager: SoftwareManager = self.software_manager
        ftp_client_service: Optional[FTPClient] = software_manager.software.get("FTPClient")
        if ftp_client_service is None:
            # .get() yields None when no FTP client is registered; fail cleanly instead of raising.
            self.sys_log.error(f"{self.name}: FTPClient is not available. Unable to create database backup.")
            return False

        # send backup copy of database file to FTP server
        response = ftp_client_service.send_file(
            dest_ip_address=self.backup_server_ip,
            src_file_name=self._db_file.name,
            src_folder_name=self.folder.name,
            dest_folder_name=str(self.uuid),
            dest_file_name="database.db",
        )

        if response:
            return True

        self.sys_log.error("Unable to create database backup.")
        return False

    def restore_backup(self) -> bool:
        """
        Restore a backup from backup server.

        The backup is downloaded into the "downloads" folder, the current database file is deleted and the
        downloaded copy is copied into the database folder. On success the health state is set to GOOD.

        :return: True if the database file was replaced by the backup, otherwise False.
        """
        # check if this action can be performed
        if not self._can_perform_action():
            return False

        # check if the backup server was configured (mirrors backup_database)
        if self.backup_server_ip is None:
            self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
            return False

        software_manager: SoftwareManager = self.software_manager
        ftp_client_service: Optional[FTPClient] = software_manager.software.get("FTPClient")
        if ftp_client_service is None:
            self.sys_log.error(f"{self.name}: FTPClient is not available. Unable to restore database backup.")
            return False

        # retrieve backup file from backup server
        response = ftp_client_service.request_file(
            src_folder_name=str(self.uuid),
            src_file_name="database.db",
            dest_folder_name="downloads",
            dest_file_name="database.db",
            dest_ip_address=self.backup_server_ip,
        )

        if not response:
            self.sys_log.error("Unable to restore database backup.")
            return False

        # replace db file: remove the current "database.db" so the downloaded copy takes its place
        self.file_system.delete_file(folder_name=self.folder.name, file_name="database.db")
        self.file_system.copy_file(
            src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
        )
        self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")

        if self._db_file is None:
            self.sys_log.error("Copying database backup failed.")
            return False

        self.set_health_state(SoftwareHealthState.GOOD)
        return True

    def _create_db_file(self):
        """Create "database.db" in the "database" folder and remember both the file and its folder."""
        self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
        self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)

    def _process_connect(
        self, session_id: str, password: Optional[str] = None
    ) -> Dict[str, Union[int, Dict[str, bool]]]:
        """
        Handle a connect request and record the session if it is authorised.

        Status codes produced:
            - 404 : the service is not in the RUNNING operating state
            - 503 : the service is running but its health state is not GOOD
            - 401 : the supplied password does not match
            - 200 : connection accepted; the session is added to ``connections``

        :param session_id: The session identifier of the requesting client.
        :param password: The password supplied by the client, if any.
        :return: Dictionary with "status_code", "type" ("connect_response") and a boolean "response".
        """
        status_code = 500  # Default internal server error
        if self.operating_state == ServiceOperatingState.RUNNING:
            status_code = 503  # service unavailable
            if self.health_state_actual == SoftwareHealthState.GOOD:
                if self.password == password:
                    status_code = 200  # ok
                    self.connections[session_id] = datetime.now()
                    self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
                else:
                    status_code = 401  # Unauthorised
                    self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
        else:
            status_code = 404  # service not found
        return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}

    def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
        """
        Executes the given SQL query and returns the result.

        Possible queries:
            - SELECT : succeeds (200) only while the database file health status is GOOD, otherwise 404
            - DELETE : marks the database file as COMPROMISED and returns 200

        Any other query yields a 500 response.

        :param query: The SQL query to be executed.
        :param query_id: Identifier of the query, echoed back as "uuid" in successful responses.
        :return: Dictionary containing status code and data fetched.
        """
        self.sys_log.info(f"{self.name}: Running {query}")
        if query == "SELECT":
            if self._db_file.health_status == FileSystemItemHealthStatus.GOOD:
                return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
            else:
                return {"status_code": 404, "data": False}
        elif query == "DELETE":
            self._db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
            return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
        else:
            # Invalid query
            return {"status_code": 500, "data": False}

    def describe_state(self) -> Dict:
        """
        Produce a dictionary describing the current state of this object.

        Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.

        :return: Current state of this object and child objects.
        :rtype: Dict
        """
        return super().describe_state()

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """
        Processes the incoming SQL payload and sends the result back.

        Recognised payload types are "connect_request", "disconnect" and "sql". A "sql" payload is only
        executed for a session present in ``connections``; otherwise a 401 response is sent. Payloads that
        are not handled result in the default 500 response being sent.

        :param payload: The request payload (a dict with a "type" key is expected).
        :param session_id: The session identifier.
        :return: False if the parent ``receive`` rejects the payload; otherwise True once a response has
            been sent, regardless of the response status code.
        """
        if not super().receive(payload=payload, session_id=session_id, **kwargs):
            return False

        result = {"status_code": 500, "data": []}

        if isinstance(payload, dict) and payload.get("type"):
            if payload["type"] == "connect_request":
                result = self._process_connect(session_id=session_id, password=payload.get("password"))
            elif payload["type"] == "disconnect":
                if session_id in self.connections:
                    self.connections.pop(session_id)
            elif payload["type"] == "sql":
                if session_id in self.connections:
                    result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
                else:
                    result = {"status_code": 401, "type": "sql"}
        self.send(payload=result, session_id=session_id)
        return True

    def send(self, payload: Any, session_id: str, **kwargs) -> bool:
        """
        Send a SQL response back down to the SessionManager.

        :param payload: The SQL query results; must contain a "status_code" key.
        :param session_id: The session identifier.
        :return: True if the Status Code is 200, otherwise False.
        """
        software_manager: SoftwareManager = self.software_manager
        software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)

        return payload["status_code"] == 200