426 lines
17 KiB
Python
426 lines
17 KiB
Python
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
|
from __future__ import annotations
|
|
|
|
from ipaddress import IPv4Address
|
|
from typing import Any, Dict, Optional, Union
|
|
from uuid import uuid4
|
|
|
|
from prettytable import MARKDOWN, PrettyTable
|
|
from pydantic import BaseModel
|
|
|
|
from primaite.interface.request import RequestFormat, RequestResponse
|
|
from primaite.simulator.core import RequestManager, RequestType
|
|
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
|
from primaite.simulator.system.applications.application import Application
|
|
from primaite.simulator.system.core.software_manager import SoftwareManager
|
|
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
|
from primaite.utils.validation.ipv4_address import IPV4Address
|
|
from primaite.utils.validation.port import PORT_LOOKUP
|
|
|
|
|
|
class DatabaseClientConnection(BaseModel):
|
|
"""
|
|
DatabaseClientConnection Class.
|
|
|
|
This class is used to record current DatabaseConnections within the DatabaseClient class.
|
|
"""
|
|
|
|
connection_id: str
|
|
"""Connection UUID."""
|
|
|
|
parent_node: HostNode
|
|
"""The parent Node that this connection was created on."""
|
|
|
|
is_active: bool = True
|
|
"""Flag to state whether the connection is still active or not."""
|
|
|
|
@property
|
|
def client(self) -> Optional[DatabaseClient]:
|
|
"""The DatabaseClient that holds this connection."""
|
|
return self.parent_node.software_manager.software.get("DatabaseClient")
|
|
|
|
def query(self, sql: str) -> bool:
|
|
"""
|
|
Query the databaseserver.
|
|
|
|
:return: Boolean value
|
|
"""
|
|
if self.is_active and self.client:
|
|
return self.client._query(connection_id=self.connection_id, sql=sql) # noqa
|
|
return False
|
|
|
|
def disconnect(self):
|
|
"""Disconnect the connection."""
|
|
if self.client and self.is_active:
|
|
self.client._disconnect(self.connection_id) # noqa
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})"
|
|
|
|
def __repr__(self) -> str:
|
|
return str(self)
|
|
|
|
|
|
class DatabaseClient(Application, identifier="DatabaseClient"):
|
|
"""
|
|
A DatabaseClient application.
|
|
|
|
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
|
|
Database Service. It mainly operates over TCP protocol.
|
|
|
|
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
|
|
"""
|
|
|
|
server_ip_address: Optional[IPv4Address] = None
|
|
server_password: Optional[str] = None
|
|
_query_success_tracker: Dict[str, bool] = {}
|
|
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
|
last_query_response: Optional[Dict] = None
|
|
"""Keep track of the latest query response. Used to determine rewards."""
|
|
_server_connection_id: Optional[str] = None
|
|
"""Connection ID to the Database Server."""
|
|
client_connections: Dict[str, DatabaseClientConnection] = {}
|
|
"""Keep track of active connections to Database Server."""
|
|
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
|
|
"""Dictionary of connection requests to Database Server."""
|
|
connected: bool = False
|
|
"""Boolean Value for whether connected to DB Server."""
|
|
native_connection: Optional[DatabaseClientConnection] = None
|
|
"""Native Client Connection for using the client directly (similar to psql in a terminal)."""
|
|
|
|
def __init__(self, **kwargs):
|
|
kwargs["name"] = "DatabaseClient"
|
|
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
|
|
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
|
super().__init__(**kwargs)
|
|
|
|
def _init_request_manager(self) -> RequestManager:
|
|
"""
|
|
Initialise the request manager.
|
|
|
|
More information in user guide and docstring for SimComponent._init_request_manager.
|
|
"""
|
|
rm = super()._init_request_manager()
|
|
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
|
|
|
|
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
|
|
ip, pw = request[-1].get("server_ip_address"), request[-1].get("server_password")
|
|
ip = None if ip is None else IPV4Address(ip)
|
|
success = self.configure(server_ip_address=ip, server_password=pw)
|
|
return RequestResponse.from_bool(success)
|
|
|
|
rm.add_request("configure", RequestType(func=_configure))
|
|
return rm
|
|
|
|
def execute(self) -> bool:
|
|
"""Execution definition for db client: perform a select query."""
|
|
if not self._can_perform_action():
|
|
return False
|
|
|
|
self.num_executions += 1 # trying to connect counts as an execution
|
|
|
|
if not self.native_connection:
|
|
self.connect()
|
|
|
|
if self.native_connection:
|
|
return self.check_connection(connection_id=self.native_connection.connection_id)
|
|
|
|
return False
|
|
|
|
def describe_state(self) -> Dict:
|
|
"""
|
|
Describes the current state of the ACLRule.
|
|
|
|
:return: A dictionary representing the current state.
|
|
"""
|
|
state = super().describe_state()
|
|
return state
|
|
|
|
def show(self, markdown: bool = False):
|
|
"""
|
|
Display the client connections in tabular format.
|
|
|
|
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
|
"""
|
|
table = PrettyTable(["Connection ID", "Active"])
|
|
if markdown:
|
|
table.set_style(MARKDOWN)
|
|
table.align = "l"
|
|
table.title = f"{self.sys_log.hostname} {self.name} Client Connections"
|
|
if self.native_connection:
|
|
table.add_row([self.native_connection.connection_id, self.native_connection.is_active])
|
|
for connection_id, connection in self.client_connections.items():
|
|
table.add_row([connection_id, connection.is_active])
|
|
print(table.get_string(sortby="Connection ID"))
|
|
|
|
def configure(self, server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None) -> bool:
|
|
"""
|
|
Configure the DatabaseClient to communicate with a DatabaseService.
|
|
|
|
:param server_ip_address: The IP address of the Node the DatabaseService is on.
|
|
:param server_password: The password on the DatabaseService.
|
|
"""
|
|
self.server_ip_address = server_ip_address or self.server_ip_address
|
|
self.server_password = server_password or self.server_password
|
|
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
|
return True
|
|
|
|
def connect(self) -> bool:
|
|
"""Connect the native client connection."""
|
|
if self.native_connection:
|
|
return True
|
|
self.native_connection = self.get_new_connection()
|
|
return self.native_connection is not None
|
|
|
|
def disconnect(self):
|
|
"""Disconnect the native client connection."""
|
|
if self.native_connection:
|
|
self._disconnect(self.native_connection.connection_id)
|
|
self.native_connection = None
|
|
|
|
def check_connection(self, connection_id: str) -> bool:
|
|
"""Check whether the connection can be successfully re-established.
|
|
|
|
:param connection_id: connection ID to check
|
|
:type connection_id: str
|
|
:return: Whether the connection was successfully re-established.
|
|
:rtype: bool
|
|
"""
|
|
if not self._can_perform_action():
|
|
return False
|
|
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
|
|
|
def _validate_client_connection_request(self, connection_id: str) -> bool:
|
|
"""Check that client_connection_id is valid."""
|
|
return True if connection_id in self._client_connection_requests else False
|
|
|
|
def _connect(
|
|
self,
|
|
server_ip_address: IPv4Address,
|
|
connection_request_id: str,
|
|
password: Optional[str] = None,
|
|
is_reattempt: bool = False,
|
|
) -> Optional[DatabaseClientConnection]:
|
|
"""
|
|
Connects the DatabaseClient to the DatabaseServer.
|
|
|
|
:param: server_ip_address: IP address of the database server
|
|
:type: server_ip_address: IPv4Address
|
|
|
|
:param: password: Password used to connect to the database server. Optional.
|
|
:type: password: Optional[str]
|
|
|
|
:param: is_reattempt: True if the connect request has been reattempted. Default False
|
|
:type: is_reattempt: Optional[bool]
|
|
"""
|
|
if is_reattempt:
|
|
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
|
|
if valid_connection_request:
|
|
database_client_connection = self._client_connection_requests.pop(connection_request_id)
|
|
if isinstance(database_client_connection, DatabaseClientConnection):
|
|
self.sys_log.info(
|
|
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
|
|
f"Using connection id {database_client_connection}"
|
|
)
|
|
self.connected = True
|
|
return database_client_connection
|
|
else:
|
|
self.sys_log.info(
|
|
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
|
|
)
|
|
return None
|
|
else:
|
|
self.sys_log.info(
|
|
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
|
|
f"due to unknown client-side connection request id"
|
|
)
|
|
return None
|
|
|
|
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
|
|
software_manager: SoftwareManager = self.software_manager
|
|
software_manager.send_payload_to_session_manager(
|
|
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
|
)
|
|
return self._connect(
|
|
server_ip_address=server_ip_address,
|
|
password=password,
|
|
is_reattempt=True,
|
|
connection_request_id=connection_request_id,
|
|
)
|
|
|
|
def _disconnect(self, connection_id: str) -> bool:
|
|
"""Disconnect from the Database Service.
|
|
|
|
If no connection_id is provided, connect from first ID in
|
|
self.client_connections.
|
|
|
|
:param: connection_id: connection ID to disconnect.
|
|
:type: connection_id: str
|
|
|
|
:return: bool
|
|
"""
|
|
if not self._can_perform_action():
|
|
return False
|
|
|
|
# if there are no connections - nothing to disconnect
|
|
if len(self.client_connections) == 0:
|
|
self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.")
|
|
return False
|
|
if not self.client_connections.get(connection_id):
|
|
return False
|
|
software_manager: SoftwareManager = self.software_manager
|
|
software_manager.send_payload_to_session_manager(
|
|
payload={"type": "disconnect", "connection_id": connection_id},
|
|
dest_ip_address=self.server_ip_address,
|
|
dest_port=self.port,
|
|
)
|
|
connection = self.client_connections.pop(connection_id)
|
|
self.terminate_connection(connection_id=connection_id)
|
|
|
|
connection.is_active = False
|
|
|
|
self.sys_log.info(f"{self.name}: DatabaseClient disconnected {connection_id} from {self.server_ip_address}")
|
|
self.connected = False
|
|
return True
|
|
|
|
def uninstall(self) -> None:
|
|
"""
|
|
Uninstall the DatabaseClient.
|
|
|
|
Calls disconnect on all client connections to ensure that both client and server connections are killed.
|
|
"""
|
|
while self.client_connections:
|
|
conn_key = next(iter(self.client_connections.keys()))
|
|
conn_obj: DatabaseClientConnection = self.client_connections[conn_key]
|
|
conn_obj.disconnect()
|
|
if conn_obj.is_active or conn_key in self.client_connections:
|
|
self.sys_log.error(
|
|
"Attempted to uninstall database client but could not drop active connections. "
|
|
"Forcing uninstall anyway."
|
|
)
|
|
self.client_connections.pop(conn_key, None)
|
|
super().uninstall()
|
|
|
|
def get_new_connection(self) -> Optional[DatabaseClientConnection]:
|
|
"""Get a new connection to the DatabaseServer.
|
|
|
|
:return: DatabaseClientConnection object
|
|
"""
|
|
if not self._can_perform_action():
|
|
return None
|
|
if self.server_ip_address is None:
|
|
self.sys_log.warning(f"{self.name}: Database server IP address not provided.")
|
|
return None
|
|
|
|
connection_request_id = str(uuid4())
|
|
self._client_connection_requests[connection_request_id] = None
|
|
|
|
self.sys_log.info(
|
|
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
|
|
)
|
|
|
|
return self._connect(
|
|
server_ip_address=self.server_ip_address,
|
|
password=self.server_password,
|
|
connection_request_id=connection_request_id,
|
|
)
|
|
|
|
def _create_client_connection(self, connection_id: str, connection_request_id: str) -> None:
|
|
"""Create a new DatabaseClientConnection Object."""
|
|
client_connection = DatabaseClientConnection(
|
|
connection_id=connection_id, client=self, parent_node=self.software_manager.node
|
|
)
|
|
self.client_connections[connection_id] = client_connection
|
|
self._client_connection_requests[connection_request_id] = client_connection
|
|
|
|
def _query(self, sql: str, connection_id: str, query_id: Optional[str] = False, is_reattempt: bool = False) -> bool:
|
|
"""
|
|
Send a query to the connected database server.
|
|
|
|
:param: sql: SQL query to send to the database server.
|
|
:type: sql: str
|
|
|
|
:param: query_id: ID of the query, used as reference
|
|
:type: query_id: str
|
|
|
|
:param: connection_id: ID of the connection to the database server.
|
|
:type: connection_id: str
|
|
|
|
:param: is_reattempt: True if the query request has been reattempted. Default False
|
|
:type: is_reattempt: Optional[bool]
|
|
"""
|
|
if not query_id:
|
|
query_id = str(uuid4())
|
|
if is_reattempt:
|
|
success = self._query_success_tracker.get(query_id)
|
|
if success:
|
|
self.sys_log.info(f"{self.name}: Query successful {sql}")
|
|
return True
|
|
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
|
|
return False
|
|
else:
|
|
software_manager: SoftwareManager = self.software_manager
|
|
software_manager.send_payload_to_session_manager(
|
|
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
|
|
dest_ip_address=self.server_ip_address,
|
|
dest_port=self.port,
|
|
)
|
|
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
|
|
|
|
def run(self) -> None:
|
|
"""Run the DatabaseClient."""
|
|
super().run()
|
|
|
|
def query(self, sql: str) -> bool:
|
|
"""
|
|
Send a query to the Database Service.
|
|
|
|
:param: sql: The SQL query.
|
|
:type: sql: str
|
|
|
|
:return: True if the query was successful, otherwise False.
|
|
"""
|
|
if not self._can_perform_action():
|
|
return False
|
|
|
|
if not self.native_connection:
|
|
return False
|
|
|
|
# reset last query response
|
|
self.last_query_response = None
|
|
|
|
uuid = str(uuid4())
|
|
self._query_success_tracker[uuid] = False
|
|
return self.native_connection.query(sql)
|
|
|
|
def receive(self, session_id: str, payload: Any, **kwargs) -> bool:
|
|
"""
|
|
Receive a payload from the Software Manager.
|
|
|
|
:param payload: A payload to receive.
|
|
:param session_id: The session id the payload relates to.
|
|
:return: True.
|
|
"""
|
|
if not self._can_perform_action():
|
|
return False
|
|
if isinstance(payload, dict) and payload.get("type"):
|
|
if payload["type"] == "connect_response":
|
|
if payload["response"] is True:
|
|
# add connection
|
|
connection_id = payload["connection_id"]
|
|
self._create_client_connection(
|
|
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
|
|
)
|
|
elif payload["type"] == "sql":
|
|
self.last_query_response = payload
|
|
query_id = payload.get("uuid")
|
|
status_code = payload.get("status_code")
|
|
self._query_success_tracker[query_id] = status_code == 200
|
|
if self._query_success_tracker[query_id]:
|
|
self.sys_log.debug(f"Received {payload=}")
|
|
elif payload["type"] == "disconnect":
|
|
connection_id = payload["connection_id"]
|
|
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server")
|
|
self._disconnect(payload["connection_id"])
|
|
return True
|