Merged PR 348: #2462 - Refactor of DatabaseClient and DatabaseServer
## Summary Refactor of `DatabaseClient` and `DatabaseService` to update how connection IDs are generated. These are now provided by DatabaseService when establishing a connection. Creation of `DatabaseClientConnection` class. This is used by `DatabaseClient` to hold a dictionary of active db connections. ## Test process Tests have been updated to reflect the changes and all pass ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [X] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [X] attended to any **TO-DOs** left in the code Related work items: #2462
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
@@ -145,8 +146,16 @@ class DatabaseService(Service):
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _generate_connection_id(self) -> str:
|
||||
"""Generate a unique connection ID."""
|
||||
return str(uuid4())
|
||||
|
||||
def _process_connect(
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
self,
|
||||
src_ip: IPv4Address,
|
||||
connection_request_id: str,
|
||||
password: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
"""Process an incoming connection request.
|
||||
|
||||
@@ -158,17 +167,17 @@ class DatabaseService(Service):
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
status_code = 500 # Default internal server error
|
||||
connection_id = None
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
|
||||
)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
connection_id = self._generate_connection_id()
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id):
|
||||
if not self.add_connection(connection_id=connection_id, session_id=session_id):
|
||||
status_code = 500
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
@@ -183,6 +192,7 @@ class DatabaseService(Service):
|
||||
"type": "connect_response",
|
||||
"response": status_code == 200,
|
||||
"connection_id": connection_id,
|
||||
"connection_request_id": connection_request_id,
|
||||
}
|
||||
|
||||
def _process_sql(
|
||||
@@ -299,19 +309,34 @@ class DatabaseService(Service):
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
src_ip = kwargs.get("frame").ip.src_ip_address
|
||||
result = self._process_connect(
|
||||
connection_id=payload.get("connection_id"), password=payload.get("password")
|
||||
src_ip=src_ip,
|
||||
password=payload.get("password"),
|
||||
connection_request_id=payload.get("connection_request_id"),
|
||||
session_id=session_id,
|
||||
)
|
||||
elif payload["type"] == "disconnect":
|
||||
if payload["connection_id"] in self.connections:
|
||||
self.remove_connection(connection_id=payload["connection_id"])
|
||||
connection_id = payload["connection_id"]
|
||||
connected_ip_address = self.connections[connection_id]["ip_address"]
|
||||
frame = kwargs.get("frame")
|
||||
if connected_ip_address == frame.ip.src_ip_address:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}"
|
||||
)
|
||||
self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False)
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source "
|
||||
f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})"
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
if payload.get("connection_id") in self.connections:
|
||||
result = self._process_sql(
|
||||
|
||||
Reference in New Issue
Block a user