Files
PrimAITE/src/primaite/simulator/system/applications/database_client.py
2024-03-11 10:20:47 +00:00

260 lines
10 KiB
Python

from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from uuid import uuid4
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.software_manager import SoftwareManager
# Module-level logger, named after this module via getLogger(__name__).
_LOGGER = getLogger(__name__)
class DatabaseClient(Application):
    """
    A DatabaseClient application.

    Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
    Database Service. It mainly operates over TCP protocol.

    :ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
    :ivar server_password: The password sent with connect requests, defaults to None.
    :ivar connected: Result of the most recent connect attempt; reset to False by a successful disconnect.
    """

    server_ip_address: Optional[IPv4Address] = None
    server_password: Optional[str] = None
    connected: bool = False
    # Maps a query id to whether a response with status code 200 has been received for it.
    _query_success_tracker: Dict[str, bool] = {}
    _last_connection_successful: Optional[bool] = None
    """Keep track of connections that were established or verified during this step. Used for rewards."""

    def __init__(self, **kwargs):
        """Create the client, forcing its name, port and protocol regardless of what the caller passed."""
        kwargs["name"] = "DatabaseClient"
        kwargs["port"] = Port.POSTGRES_SERVER
        kwargs["protocol"] = IPProtocol.TCP
        super().__init__(**kwargs)

    def _init_request_manager(self) -> RequestManager:
        """
        Initialise the request manager.

        Adds an ``execute`` request that runs :meth:`execute` and wraps its boolean result in a RequestResponse.
        More information in user guide and docstring for SimComponent._init_request_manager.

        :return: The request manager of the parent class, extended with the ``execute`` request.
        """
        rm = super()._init_request_manager()
        rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
        return rm

    def execute(self) -> bool:
        """
        Execution definition for db client: perform a select query.

        The most recently added connection is checked if one exists; otherwise a fresh connection id is
        generated, which makes the underlying query attempt to connect first.

        :return: Whether the connection check succeeded. The result is also stored for ``describe_state``.
        """
        if self.connections:
            can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
        else:
            can_connect = self.check_connection(connection_id=str(uuid4()))
        self._last_connection_successful = can_connect
        return can_connect

    def describe_state(self) -> Dict:
        """
        Describe the current state of the DatabaseClient.

        :return: The parent state dictionary, extended with ``last_connection_successful``.
        """
        state = super().describe_state()
        # Outcome of the last execute() call (None if execute() has not been called yet).
        state["last_connection_successful"] = self._last_connection_successful
        return state

    def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
        """
        Configure the DatabaseClient to communicate with a DatabaseService.

        :param server_ip_address: The IP address of the Node the DatabaseService is on.
        :param server_password: The password on the DatabaseService.
        """
        self.server_ip_address = server_ip_address
        self.server_password = server_password
        # NOTE(review): the password is written to the sys log in plain text - confirm this is intended.
        self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")

    def connect(self, connection_id: Optional[str] = None) -> bool:
        """
        Connect to a Database Service.

        :param connection_id: Id to use for the connection. A new UUID is generated when omitted.
        :return: True if the connection was established, otherwise False.
        """
        if not self._can_perform_action():
            return False

        if not connection_id:
            connection_id = str(uuid4())

        self.connected = self._connect(
            server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
        )
        return self.connected

    def check_connection(self, connection_id: str) -> bool:
        """
        Check whether the connection can be successfully re-established.

        :param connection_id: connection ID to check
        :type connection_id: str
        :return: Whether the connection was successfully re-established.
        :rtype: bool
        """
        if not self._can_perform_action():
            return False
        return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)

    def _connect(
        self,
        server_ip_address: IPv4Address,
        connection_id: Optional[str] = None,
        password: Optional[str] = None,
        is_reattempt: bool = False,
    ) -> bool:
        """
        Connect the DatabaseClient to the DatabaseServer.

        The first call sends a ``connect_request`` payload and then calls itself with ``is_reattempt=True``
        to check whether the connection id has been registered in ``self.connections``.

        :param server_ip_address: IP address of the database server
        :type server_ip_address: IPv4Address
        :param connection_id: Id of the connection being requested. Optional.
        :type connection_id: Optional[str]
        :param password: Password used to connect to the database server. Optional.
        :type password: Optional[str]
        :param is_reattempt: True if the connect request has been reattempted. Default False
        :type is_reattempt: bool
        :return: True if the connection id is present in ``self.connections`` after the request.
        :rtype: bool
        """
        if is_reattempt:
            if self.connections.get(connection_id):
                self.sys_log.info(
                    f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
                )
                self.server_ip_address = server_ip_address
                return True

            self.sys_log.info(
                f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
            )
            return False

        payload = {
            "type": "connect_request",
            "password": password,
            "connection_id": connection_id,
        }
        software_manager: SoftwareManager = self.software_manager
        software_manager.send_payload_to_session_manager(
            payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
        )
        return self._connect(
            server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
        )

    def disconnect(self, connection_id: Optional[str] = None) -> bool:
        """
        Disconnect from the Database Service.

        :param connection_id: Id of the connection to close. The first known connection is used when omitted.
        :return: True if a disconnect was sent and the connection removed, otherwise False.
        """
        if not self._can_perform_action():
            self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
            return False

        # if there are no connections - nothing to disconnect
        if not self.connections:
            self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
            return False

        # if no connection provided, disconnect the first connection
        if not connection_id:
            connection_id = list(self.connections.keys())[0]

        software_manager: SoftwareManager = self.software_manager
        software_manager.send_payload_to_session_manager(
            payload={"type": "disconnect", "connection_id": connection_id},
            dest_ip_address=self.server_ip_address,
            dest_port=self.port,
        )
        self.remove_connection(connection_id=connection_id)

        self.sys_log.info(
            f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
        )
        self.connected = False
        # Honour the declared ``-> bool`` contract; previously the success path implicitly returned None.
        return True

    def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
        """
        Send a query to the connected database server.

        The first call sends the ``sql`` payload and then calls itself with ``is_reattempt=True`` to read
        the outcome recorded for ``query_id`` in the query success tracker.

        :param sql: SQL query to send to the database server.
        :type sql: str
        :param query_id: ID of the query, used as reference
        :type query_id: str
        :param connection_id: ID of the connection the query is sent on.
        :type connection_id: str
        :param is_reattempt: True if the query request has been reattempted. Default False
        :type is_reattempt: bool
        :return: True if the tracker holds a success for ``query_id``, otherwise False.
        :rtype: bool
        """
        if is_reattempt:
            success = self._query_success_tracker.get(query_id)
            if success:
                self.sys_log.info(f"{self.name}: Query successful {sql}")
                return True
            self.sys_log.info(f"{self.name}: Unable to run query {sql}")
            return False

        software_manager: SoftwareManager = self.software_manager
        software_manager.send_payload_to_session_manager(
            payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
            dest_ip_address=self.server_ip_address,
            dest_port=self.port,
        )
        return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)

    def run(self) -> None:
        """Run the DatabaseClient."""
        super().run()

    def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
        """
        Send a query to the Database Service.

        :param sql: The SQL query.
        :param connection_id: ID of the connection to use. When omitted, the most recently added connection
            is used, or a new ID is generated if there are no connections. An unknown ID triggers a connect
            attempt first.
        :return: True if the query was successful, otherwise False.
        """
        if not self._can_perform_action():
            return False

        if connection_id is None:
            if self.connections:
                connection_id = list(self.connections.keys())[-1]
                # TODO: if the most recent connection dies, it should be automatically cleared.
            else:
                connection_id = str(uuid4())

        # Unknown connection id: try to establish the connection before querying.
        if not self.connections.get(connection_id):
            if not self.connect(connection_id=connection_id):
                return False

        # Initialise the tracker of this ID to False
        uuid = str(uuid4())
        self._query_success_tracker[uuid] = False
        return self._query(sql=sql, query_id=uuid, connection_id=connection_id)

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """
        Receive a payload from the Software Manager.

        A ``connect_response`` whose ``response`` is True registers the connection; an ``sql`` payload records
        whether its ``status_code`` was 200 against the query's ``uuid``.

        :param payload: A payload to receive.
        :param session_id: The session id the payload relates to.
        :return: False if the application cannot currently perform actions, otherwise True.
        """
        if not self._can_perform_action():
            return False

        if isinstance(payload, dict) and payload.get("type"):
            if payload["type"] == "connect_response":
                # .get() so a malformed response without a "response" key is ignored instead of raising.
                if payload.get("response") is True:
                    # add connection
                    self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
            elif payload["type"] == "sql":
                query_id = payload.get("uuid")
                status_code = payload.get("status_code")
                self._query_success_tracker[query_id] = status_code == 200
                if self._query_success_tracker[query_id]:
                    _LOGGER.debug(f"Received payload {payload}")
        return True