Files
PrimAITE/src/primaite/simulator/system/applications/database_client.py
2024-10-30 11:15:39 +00:00

426 lines
17 KiB
Python

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
class DatabaseClientConnection(BaseModel):
    """
    A single client-side connection to a DatabaseService.

    Instances are created by the DatabaseClient when the server accepts a connect request, and are recorded in the
    client's ``client_connections`` mapping.
    """

    connection_id: str
    """Connection UUID."""
    parent_node: HostNode
    """The parent Node that this connection was created on."""
    is_active: bool = True
    """Flag to state whether the connection is still active or not."""

    @property
    def client(self) -> Optional[DatabaseClient]:
        """
        Look up the DatabaseClient installed on ``parent_node``.

        :return: The ``"DatabaseClient"`` entry of the parent node's software manager, or None when it is absent.
        """
        installed_software = self.parent_node.software_manager.software
        return installed_software.get("DatabaseClient")

    def query(self, sql: str) -> bool:
        """
        Send an SQL query to the database server over this connection.

        :param sql: The SQL statement to run.
        :return: The result reported by the owning client, or False (without sending anything) when this connection
            is inactive or no DatabaseClient is installed on the parent node.
        """
        if not self.is_active:
            return False
        owner = self.client
        if not owner:
            return False
        return owner._query(connection_id=self.connection_id, sql=sql)  # noqa

    def disconnect(self):
        """Ask the owning DatabaseClient to close this connection; does nothing if inactive or no client is found."""
        owner = self.client
        if not owner or not self.is_active:
            return
        owner._disconnect(self.connection_id)  # noqa

    def __str__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(connection_id='{self.connection_id}', is_active={self.is_active})"

    def __repr__(self) -> str:
        return str(self)
class DatabaseClient(Application, identifier="DatabaseClient"):
    """
    A DatabaseClient application.

    Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
    Database Service. It mainly operates over TCP protocol.

    The client owns an optional "native" connection (``native_connection``) for direct use, and can hand out further
    connections through :meth:`get_new_connection`. Every accepted connection, including the native one, is
    registered in ``client_connections``.

    :ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
    """

    server_ip_address: Optional[IPv4Address] = None
    server_password: Optional[str] = None
    _query_success_tracker: Dict[str, bool] = {}
    """Success flag of each query sent by this client, keyed by query id (filled in by ``receive``). Used for rewards."""
    last_query_response: Optional[Dict] = None
    """Keep track of the latest query response. Used to determine rewards."""
    _server_connection_id: Optional[str] = None
    """Connection ID to the Database Server."""
    client_connections: Dict[str, DatabaseClientConnection] = {}
    """Keep track of active connections to Database Server."""
    _client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
    """Dictionary of connection requests to Database Server."""
    connected: bool = False
    """Boolean Value for whether connected to DB Server."""
    native_connection: Optional[DatabaseClientConnection] = None
    """Native Client Connection for using the client directly (similar to psql in a terminal)."""

    def __init__(self, **kwargs):
        kwargs["name"] = "DatabaseClient"
        kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
        kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
        super().__init__(**kwargs)

    def _init_request_manager(self) -> RequestManager:
        """
        Initialise the request manager.

        Registers the ``execute`` and ``configure`` requests on top of those provided by the parent class.
        More information in user guide and docstring for SimComponent._init_request_manager.
        """
        rm = super()._init_request_manager()
        rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))

        def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
            # The last element of the request carries the configuration options.
            ip, pw = request[-1].get("server_ip_address"), request[-1].get("server_password")
            ip = None if ip is None else IPV4Address(ip)
            success = self.configure(server_ip_address=ip, server_password=pw)
            return RequestResponse.from_bool(success)

        rm.add_request("configure", RequestType(func=_configure))
        return rm

    def execute(self) -> bool:
        """
        Execution definition for db client: (re)connect the native connection if needed, then check it.

        :return: Result of :meth:`check_connection` on the native connection, or False if no connection could be made.
        """
        if not self._can_perform_action():
            return False
        self.num_executions += 1  # trying to connect counts as an execution
        # A native connection closed by a disconnect (e.g. one sent by the server) is still referenced but inactive;
        # treat it like a missing connection so a fresh one is requested.
        if not (self.native_connection and self.native_connection.is_active):
            self.connect()
        if self.native_connection:
            return self.check_connection(connection_id=self.native_connection.connection_id)
        return False

    def describe_state(self) -> Dict:
        """
        Describe the current state of the DatabaseClient.

        :return: A dictionary representing the current state (as produced by the parent Application class).
        """
        return super().describe_state()

    def show(self, markdown: bool = False):
        """
        Display the client connections in tabular format.

        Each connection is listed once. The native connection is included even when it is no longer registered in
        ``client_connections``.

        :param markdown: Whether to display the table in Markdown format or not. Default is `False`.
        """
        table = PrettyTable(["Connection ID", "Active"])
        if markdown:
            table.set_style(MARKDOWN)
        table.align = "l"
        table.title = f"{self.sys_log.hostname} {self.name} Client Connections"
        native_id = None
        if self.native_connection:
            native_id = self.native_connection.connection_id
            table.add_row([native_id, self.native_connection.is_active])
        for connection_id, connection in self.client_connections.items():
            # The native connection is also stored in client_connections; skip it to avoid a duplicate row.
            if connection_id == native_id:
                continue
            table.add_row([connection_id, connection.is_active])
        print(table.get_string(sortby="Connection ID"))

    def configure(self, server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None) -> bool:
        """
        Configure the DatabaseClient to communicate with a DatabaseService.

        Arguments that are None (or otherwise falsy) leave the current value unchanged.

        :param server_ip_address: The IP address of the Node the DatabaseService is on.
        :param server_password: The password on the DatabaseService.
        :return: Always True.
        """
        self.server_ip_address = server_ip_address or self.server_ip_address
        self.server_password = server_password or self.server_password
        # Log the effective configuration (not the raw arguments) and never write the password in plain text.
        masked_password = "********" if self.server_password else None
        self.sys_log.info(
            f"{self.name}: Configured the {self.name} with server_ip_address={self.server_ip_address!r}, "
            f"server_password={masked_password!r}."
        )
        return True

    def connect(self) -> bool:
        """
        Connect the native client connection.

        :return: True if an active native connection exists or was established, otherwise False.
        """
        if self.native_connection and self.native_connection.is_active:
            return True
        # Either there is no native connection yet, or the previous one has been closed: request a new one.
        self.native_connection = self.get_new_connection()
        return self.native_connection is not None

    def disconnect(self):
        """Disconnect the native client connection and drop the reference to it."""
        if self.native_connection:
            self._disconnect(self.native_connection.connection_id)
            self.native_connection = None

    def check_connection(self, connection_id: str) -> bool:
        """
        Check that the given connection is usable by running a probe query over it.

        :param connection_id: connection ID to check
        :type connection_id: str
        :return: Whether the probe query succeeded.
        :rtype: bool
        """
        if not self._can_perform_action():
            return False
        return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)

    def _validate_client_connection_request(self, connection_id: str) -> bool:
        """Check that the connection request id was issued by this client."""
        return connection_id in self._client_connection_requests

    def _connect(
        self,
        server_ip_address: IPv4Address,
        connection_request_id: str,
        password: Optional[str] = None,
        is_reattempt: bool = False,
    ) -> Optional[DatabaseClientConnection]:
        """
        Connects the DatabaseClient to the DatabaseServer.

        The first call sends a ``connect_request`` payload and then calls itself with ``is_reattempt=True``. The
        reattempt pops the request from ``_client_connection_requests``, where ``receive`` stores the new
        DatabaseClientConnection if the server accepted it. This presumably relies on the server's response being
        delivered before the reattempt runs.

        :param: server_ip_address: IP address of the database server
        :type: server_ip_address: IPv4Address
        :param: connection_request_id: client-side id of the pending connection request
        :type: connection_request_id: str
        :param: password: Password used to connect to the database server. Optional.
        :type: password: Optional[str]
        :param: is_reattempt: True if the connect request has been reattempted. Default False
        :type: is_reattempt: Optional[bool]
        :return: The accepted DatabaseClientConnection, or None if the request was declined or unknown.
        """
        if not is_reattempt:
            payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
            software_manager: SoftwareManager = self.software_manager
            software_manager.send_payload_to_session_manager(
                payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
            )
            return self._connect(
                server_ip_address=server_ip_address,
                password=password,
                is_reattempt=True,
                connection_request_id=connection_request_id,
            )

        if not self._validate_client_connection_request(connection_id=connection_request_id):
            self.sys_log.info(
                f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
                f"due to unknown client-side connection request id"
            )
            return None

        database_client_connection = self._client_connection_requests.pop(connection_request_id)
        if isinstance(database_client_connection, DatabaseClientConnection):
            self.sys_log.info(
                f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
                f"Using connection id {database_client_connection}"
            )
            self.connected = True
            return database_client_connection

        self.sys_log.info(f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined")
        return None

    def _disconnect(self, connection_id: str) -> bool:
        """
        Disconnect a single connection from the Database Service.

        Sends a ``disconnect`` payload to the server, removes the connection from ``client_connections`` and marks
        it inactive. ``connected`` stays True while other connections remain open.

        :param: connection_id: connection ID to disconnect.
        :type: connection_id: str
        :return: True if the connection was found and disconnected, otherwise False.
        """
        if not self._can_perform_action():
            return False

        # if there are no connections - nothing to disconnect
        if not self.client_connections:
            self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.")
            return False

        if connection_id not in self.client_connections:
            return False

        software_manager: SoftwareManager = self.software_manager
        software_manager.send_payload_to_session_manager(
            payload={"type": "disconnect", "connection_id": connection_id},
            dest_ip_address=self.server_ip_address,
            dest_port=self.port,
        )
        connection = self.client_connections.pop(connection_id)
        self.terminate_connection(connection_id=connection_id)

        connection.is_active = False

        self.sys_log.info(f"{self.name}: DatabaseClient disconnected {connection_id} from {self.server_ip_address}")
        self.connected = bool(self.client_connections)
        return True

    def uninstall(self) -> None:
        """
        Uninstall the DatabaseClient.

        Calls disconnect on all client connections to ensure that both client and server connections are killed.
        Connections that cannot be disconnected are dropped anyway, with an error logged.
        """
        while self.client_connections:
            conn_key = next(iter(self.client_connections))
            conn_obj: DatabaseClientConnection = self.client_connections[conn_key]
            conn_obj.disconnect()
            if conn_obj.is_active or conn_key in self.client_connections:
                self.sys_log.error(
                    "Attempted to uninstall database client but could not drop active connections. "
                    "Forcing uninstall anyway."
                )
                self.client_connections.pop(conn_key, None)
        super().uninstall()

    def get_new_connection(self) -> Optional[DatabaseClientConnection]:
        """
        Get a new connection to the DatabaseServer.

        :return: DatabaseClientConnection object, or None if the application cannot act, no server IP is
            configured, or the server declined the request.
        """
        if not self._can_perform_action():
            return None

        if self.server_ip_address is None:
            self.sys_log.warning(f"{self.name}: Database server IP address not provided.")
            return None

        connection_request_id = str(uuid4())
        # Register the pending request so the reattempt in _connect recognises it.
        self._client_connection_requests[connection_request_id] = None
        self.sys_log.info(
            f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
        )
        return self._connect(
            server_ip_address=self.server_ip_address,
            password=self.server_password,
            connection_request_id=connection_request_id,
        )

    def _create_client_connection(self, connection_id: str, connection_request_id: str) -> None:
        """Create a new DatabaseClientConnection, register it and resolve the matching connection request."""
        client_connection = DatabaseClientConnection(
            connection_id=connection_id, parent_node=self.software_manager.node
        )
        self.client_connections[connection_id] = client_connection
        self._client_connection_requests[connection_request_id] = client_connection

    def _query(
        self, sql: str, connection_id: str, query_id: Optional[str] = None, is_reattempt: bool = False
    ) -> bool:
        """
        Send a query to the connected database server.

        The first call sends an ``sql`` payload and then calls itself with ``is_reattempt=True``, which reads the
        outcome that ``receive`` recorded in ``_query_success_tracker`` for ``query_id``.

        :param: sql: SQL query to send to the database server.
        :type: sql: str
        :param: connection_id: ID of the connection to the database server.
        :type: connection_id: str
        :param: query_id: ID of the query, used as reference; a new UUID is generated when not given.
        :type: query_id: Optional[str]
        :param: is_reattempt: True if the query request has been reattempted. Default False
        :type: is_reattempt: Optional[bool]
        :return: True if the server reported the query as successful, otherwise False.
        """
        if not query_id:
            query_id = str(uuid4())

        if is_reattempt:
            if self._query_success_tracker.get(query_id):
                self.sys_log.info(f"{self.name}: Query successful {sql}")
                return True
            self.sys_log.error(f"{self.name}: Unable to run query {sql}")
            return False

        software_manager: SoftwareManager = self.software_manager
        software_manager.send_payload_to_session_manager(
            payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
            dest_ip_address=self.server_ip_address,
            dest_port=self.port,
        )
        return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)

    def run(self) -> None:
        """Run the DatabaseClient."""
        super().run()

    def query(self, sql: str) -> bool:
        """
        Send a query to the Database Service over the native connection.

        :param: sql: The SQL query.
        :type: sql: str
        :return: True if the query was successful, otherwise False.
        """
        if not self._can_perform_action():
            return False

        if not self.native_connection:
            return False

        # reset last query response
        self.last_query_response = None
        # The connection generates its own query id in _query, so nothing is pre-registered in the tracker here.
        return self.native_connection.query(sql)

    def receive(self, session_id: str, payload: Any, **kwargs) -> bool:
        """
        Receive a payload from the Software Manager.

        Handles ``connect_response``, ``sql`` and ``disconnect`` payloads from the database server.

        :param payload: A payload to receive.
        :param session_id: The session id the payload relates to.
        :return: True, or False if the application cannot currently act.
        """
        if not self._can_perform_action():
            return False
        if isinstance(payload, dict) and payload.get("type"):
            if payload["type"] == "connect_response":
                if payload["response"] is True:
                    # add connection
                    connection_id = payload["connection_id"]
                    self._create_client_connection(
                        connection_id=connection_id, connection_request_id=payload["connection_request_id"]
                    )
            elif payload["type"] == "sql":
                self.last_query_response = payload
                query_id = payload.get("uuid")
                status_code = payload.get("status_code")
                self._query_success_tracker[query_id] = status_code == 200
                if self._query_success_tracker[query_id]:
                    self.sys_log.debug(f"Received {payload=}")
            elif payload["type"] == "disconnect":
                connection_id = payload["connection_id"]
                self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server")
                self._disconnect(connection_id)
        return True