#2059: implementing the service connections limit
This commit is contained in:
@@ -23,7 +23,7 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
connections: Dict[str, Dict] = {}
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -66,18 +66,24 @@ class DatabaseClient(Application):
|
||||
self.server_password = server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self) -> bool:
|
||||
def connect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not self.connected:
|
||||
return self._connect(self.server_ip_address, self.server_password)
|
||||
# already connected
|
||||
return True
|
||||
if not connection_id:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
|
||||
)
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
connection_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Connects the DatabaseClient to the DatabaseServer.
|
||||
@@ -92,33 +98,58 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self.connected:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
|
||||
if self.connections.get(connection_id):
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
)
|
||||
self.server_ip_address = server_ip_address
|
||||
return self.connected
|
||||
return True
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
|
||||
)
|
||||
return False
|
||||
payload = {"type": "connect_request", "password": password}
|
||||
payload = {
|
||||
"type": "connect_request",
|
||||
"password": password,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(server_ip_address, password, True)
|
||||
return self._connect(
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
|
||||
self.server_ip_address = None
|
||||
self.connected = False
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not len(self.connections):
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
return False
|
||||
|
||||
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
|
||||
# if no connection provided, disconnect the first connection
|
||||
if not connection_id:
|
||||
connection_id = list(self.connections.keys())[0]
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.connections.pop(connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
|
||||
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
|
||||
"""
|
||||
Send a query to the connected database server.
|
||||
|
||||
@@ -141,11 +172,11 @@ class DatabaseClient(Application):
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id},
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
|
||||
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the DatabaseClient."""
|
||||
@@ -153,7 +184,7 @@ class DatabaseClient(Application):
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.connect()
|
||||
|
||||
def query(self, sql: str, is_reattempt: bool = False) -> bool:
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Send a query to the Database Service.
|
||||
|
||||
@@ -164,20 +195,17 @@ class DatabaseClient(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.connected:
|
||||
query_id = str(uuid4())
|
||||
if connection_id is None:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
if not self.connections.get(connection_id):
|
||||
if not self.connect(connection_id=connection_id):
|
||||
return False
|
||||
|
||||
# Initialise the tracker of this ID to False
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
else:
|
||||
if is_reattempt:
|
||||
return False
|
||||
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
self.query(sql=sql, is_reattempt=True)
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
@@ -192,13 +220,12 @@ class DatabaseClient(Application):
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
if payload["response"] is True:
|
||||
self.connections[payload.get("connection_id")] = payload
|
||||
elif payload["type"] == "sql":
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
else:
|
||||
self.connected = False
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user