Merged PR 348: #2462 - Refactor of DatabaseClient and DatabaseServer

## Summary
Refactor of `DatabaseClient` and `DatabaseService` to update how connection IDs are generated. These are now provided by DatabaseService when establishing a connection.
Creation of `DatabaseClientConnection` class. This is used by `DatabaseClient` to hold a dictionary of active db connections.

## Test process
Tests have been updated to reflect the changes and all pass

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [X] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

Related work items: #2462
This commit is contained in:
Charlie Crane
2024-04-26 14:52:21 +00:00
parent e1ac6255ad
commit 5ee23dcb17
21 changed files with 502 additions and 156 deletions

View File

@@ -1,5 +1,6 @@
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
@@ -145,8 +146,16 @@ class DatabaseService(Service):
"""Returns the database folder."""
return self.file_system.get_folder_by_id(self.db_file.folder_id)
def _generate_connection_id(self) -> str:
"""Generate a unique connection ID."""
return str(uuid4())
def _process_connect(
self, connection_id: str, password: Optional[str] = None
self,
src_ip: IPv4Address,
connection_request_id: str,
password: Optional[str] = None,
session_id: Optional[str] = None,
) -> Dict[str, Union[int, Dict[str, bool]]]:
"""Process an incoming connection request.
@@ -158,17 +167,17 @@ class DatabaseService(Service):
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
status_code = 500 # Default internal server error
connection_id = None
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
)
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.password == password:
status_code = 200 # ok
connection_id = self._generate_connection_id()
# try to create connection
if not self.add_connection(connection_id=connection_id):
if not self.add_connection(connection_id=connection_id, session_id=session_id):
status_code = 500
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
else:
@@ -183,6 +192,7 @@ class DatabaseService(Service):
"type": "connect_response",
"response": status_code == 200,
"connection_id": connection_id,
"connection_request_id": connection_request_id,
}
def _process_sql(
@@ -299,19 +309,34 @@ class DatabaseService(Service):
:return: True if the Status Code is 200, otherwise False.
"""
result = {"status_code": 500, "data": []}
# if server service is down, return error
if not self._can_perform_action():
return False
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":
src_ip = kwargs.get("frame").ip.src_ip_address
result = self._process_connect(
connection_id=payload.get("connection_id"), password=payload.get("password")
src_ip=src_ip,
password=payload.get("password"),
connection_request_id=payload.get("connection_request_id"),
session_id=session_id,
)
elif payload["type"] == "disconnect":
if payload["connection_id"] in self.connections:
self.remove_connection(connection_id=payload["connection_id"])
connection_id = payload["connection_id"]
connected_ip_address = self.connections[connection_id]["ip_address"]
frame = kwargs.get("frame")
if connected_ip_address == frame.ip.src_ip_address:
self.sys_log.info(
f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}"
)
self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False)
else:
self.sys_log.warning(
f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source "
f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})"
)
elif payload["type"] == "sql":
if payload.get("connection_id") in self.connections:
result = self._process_sql(