Merge remote-tracking branch 'origin/dev' into feature/2476-training-schedules-mockup
This commit is contained in:
@@ -1299,7 +1299,6 @@ class Node(SimComponent):
|
||||
self.services.pop(service.uuid)
|
||||
service.parent = None
|
||||
self.sys_log.info(f"Uninstalled service {service.name}")
|
||||
_LOGGER.info(f"Removed service {service.name} from node {self.hostname}")
|
||||
self._service_request_manager.remove_request(service.name)
|
||||
|
||||
def install_application(self, application: Application) -> None:
|
||||
@@ -1335,7 +1334,6 @@ class Node(SimComponent):
|
||||
self.applications.pop(application.uuid)
|
||||
application.parent = None
|
||||
self.sys_log.info(f"Uninstalled application {application.name}")
|
||||
_LOGGER.info(f"Removed application {application.name} from node {self.hostname}")
|
||||
self._application_request_manager.remove_request(application.name)
|
||||
|
||||
def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool:
|
||||
|
||||
@@ -1,15 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
|
||||
class DatabaseClientConnection(BaseModel):
|
||||
"""
|
||||
DatabaseClientConnection Class.
|
||||
|
||||
This class is used to record current DatabaseConnections within the DatabaseClient class.
|
||||
"""
|
||||
|
||||
connection_id: str
|
||||
"""Connection UUID."""
|
||||
|
||||
parent_node: HostNode
|
||||
"""The parent Node that this connection was created on."""
|
||||
|
||||
is_active: bool = True
|
||||
"""Flag to state whether the connection is still active or not."""
|
||||
|
||||
@property
|
||||
def client(self) -> Optional[DatabaseClient]:
|
||||
"""The DatabaseClient that holds this connection."""
|
||||
return self.parent_node.software_manager.software.get("DatabaseClient")
|
||||
|
||||
def query(self, sql: str) -> bool:
|
||||
"""
|
||||
Query the databaseserver.
|
||||
|
||||
:return: Boolean value
|
||||
"""
|
||||
if self.is_active and self.client:
|
||||
return self.client._query(connection_id=self.connection_id, sql=sql) # noqa
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect the connection."""
|
||||
if self.client and self.is_active:
|
||||
self.client._disconnect(self.connection_id) # noqa
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
"""
|
||||
A DatabaseClient application.
|
||||
@@ -22,13 +65,21 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
_last_connection_successful: Optional[bool] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
"""Keep track of the latest query response. Used to determine rewards."""
|
||||
_server_connection_id: Optional[str] = None
|
||||
"""Connection ID to the Database Server."""
|
||||
client_connections: Dict[str, DatabaseClientConnection] = {}
|
||||
"""Keep track of active connections to Database Server."""
|
||||
_client_connection_requests: Dict[str, Optional[str]] = {}
|
||||
"""Dictionary of connection requests to Database Server."""
|
||||
connected: bool = False
|
||||
"""Boolean Value for whether connected to DB Server."""
|
||||
native_connection: Optional[DatabaseClientConnection] = None
|
||||
"""Native Client Connection for using the client directly (similar to psql in a terminal)."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
@@ -48,12 +99,18 @@ class DatabaseClient(Application):
|
||||
|
||||
def execute(self) -> bool:
|
||||
"""Execution definition for db client: perform a select query."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
self.num_executions += 1 # trying to connect counts as an execution
|
||||
if not self._server_connection_id:
|
||||
|
||||
if not self.native_connection:
|
||||
self.connect()
|
||||
can_connect = self.check_connection(connection_id=self._server_connection_id)
|
||||
self._last_connection_successful = can_connect
|
||||
return can_connect
|
||||
|
||||
if self.native_connection:
|
||||
return self.check_connection(connection_id=self.native_connection.connection_id)
|
||||
|
||||
return False
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -66,6 +123,23 @@ class DatabaseClient(Application):
|
||||
state["last_connection_successful"] = self._last_connection_successful
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Display the client connections in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(["Connection ID", "Active"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} {self.name} Client Connections"
|
||||
if self.native_connection:
|
||||
table.add_row([self.native_connection.connection_id, self.native_connection.is_active])
|
||||
for connection_id, connection in self.client_connections.items():
|
||||
table.add_row([connection_id, connection.is_active])
|
||||
print(table.get_string(sortby="Connection ID"))
|
||||
|
||||
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
|
||||
"""
|
||||
Configure the DatabaseClient to communicate with a DatabaseService.
|
||||
@@ -78,21 +152,17 @@ class DatabaseClient(Application):
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
"""Connect the native client connection."""
|
||||
if self.native_connection:
|
||||
return True
|
||||
self.native_connection = self.get_new_connection()
|
||||
return self.native_connection is not None
|
||||
|
||||
if not self._server_connection_id:
|
||||
self._server_connection_id = str(uuid4())
|
||||
|
||||
self.connected = self._connect(
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
connection_id=self._server_connection_id,
|
||||
)
|
||||
if not self.connected:
|
||||
self._server_connection_id = None
|
||||
return self.connected
|
||||
def disconnect(self):
|
||||
"""Disconnect the native client connection."""
|
||||
if self.native_connection:
|
||||
self._disconnect(self.native_connection.connection_id)
|
||||
self.native_connection = None
|
||||
|
||||
def check_connection(self, connection_id: str) -> bool:
|
||||
"""Check whether the connection can be successfully re-established.
|
||||
@@ -104,15 +174,19 @@ class DatabaseClient(Application):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return True if connection_id in self._client_connection_requests else False
|
||||
|
||||
def _connect(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
connection_id: Optional[str] = None,
|
||||
connection_request_id: str,
|
||||
password: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
) -> Optional[DatabaseClientConnection]:
|
||||
"""
|
||||
Connects the DatabaseClient to the DatabaseServer.
|
||||
|
||||
@@ -126,56 +200,106 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self._server_connection_id:
|
||||
valid_connection = self._check_client_connection(connection_id=connection_request_id)
|
||||
if valid_connection:
|
||||
database_client_connection = self._client_connection_requests.pop(connection_request_id)
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
)
|
||||
self.server_ip_address = server_ip_address
|
||||
return True
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
)
|
||||
return False
|
||||
payload = {
|
||||
"type": "connect_request",
|
||||
"password": password,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
server_ip_address=server_ip_address,
|
||||
password=password,
|
||||
is_reattempt=True,
|
||||
connection_request_id=connection_request_id,
|
||||
)
|
||||
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
def _disconnect(self, connection_id: str) -> bool:
|
||||
"""Disconnect from the Database Service.
|
||||
|
||||
If no connection_id is provided, connect from first ID in
|
||||
self.client_connections.
|
||||
|
||||
:param: connection_id: connection ID to disconnect.
|
||||
:type: connection_id: str
|
||||
|
||||
:return: bool
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.warning(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not self._server_connection_id:
|
||||
self.sys_log.warning(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
if len(self.client_connections) == 0:
|
||||
self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.")
|
||||
return False
|
||||
if not self.client_connections.get(connection_id):
|
||||
return False
|
||||
|
||||
# if no connection provided, disconnect the first connection
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": self._server_connection_id},
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.remove_connection(connection_id=self._server_connection_id)
|
||||
connection = self.client_connections.pop(connection_id)
|
||||
self.terminate_connection(connection_id=connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected {self._server_connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
connection.is_active = False
|
||||
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient disconnected {connection_id} from {self.server_ip_address}")
|
||||
self.connected = False
|
||||
return True
|
||||
|
||||
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
|
||||
def uninstall(self) -> None:
|
||||
"""
|
||||
Uninstall the DatabaseClient.
|
||||
|
||||
Calls disconnect on all client connections to ensure that both client and server connections are killed.
|
||||
"""
|
||||
while self.client_connections.values():
|
||||
client_connection = self.client_connections[next(iter(self.client_connections.keys()))]
|
||||
client_connection.disconnect()
|
||||
super().uninstall()
|
||||
|
||||
def get_new_connection(self) -> Optional[DatabaseClientConnection]:
|
||||
"""Get a new connection to the DatabaseServer.
|
||||
|
||||
:return: DatabaseClientConnection object
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
connection_request_id=connection_request_id,
|
||||
)
|
||||
|
||||
def _create_client_connection(self, connection_id: str, connection_request_id: str) -> None:
|
||||
"""Create a new DatabaseClientConnection Object."""
|
||||
client_connection = DatabaseClientConnection(
|
||||
connection_id=connection_id, client=self, parent_node=self.software_manager.node
|
||||
)
|
||||
self.client_connections[connection_id] = client_connection
|
||||
self._client_connection_requests[connection_request_id] = client_connection
|
||||
|
||||
def _query(self, sql: str, connection_id: str, query_id: Optional[str] = False, is_reattempt: bool = False) -> bool:
|
||||
"""
|
||||
Send a query to the connected database server.
|
||||
|
||||
@@ -185,15 +309,22 @@ class DatabaseClient(Application):
|
||||
:param: query_id: ID of the query, used as reference
|
||||
:type: query_id: str
|
||||
|
||||
:param: connection_id: ID of the connection to the database server.
|
||||
:type: connection_id: str
|
||||
|
||||
:param: is_reattempt: True if the query request has been reattempted. Default False
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if not query_id:
|
||||
query_id = str(uuid4())
|
||||
if is_reattempt:
|
||||
success = self._query_success_tracker.get(query_id)
|
||||
if success:
|
||||
self.sys_log.info(f"{self.name}: Query successful {sql}")
|
||||
self._last_connection_successful = True
|
||||
return True
|
||||
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
|
||||
self._last_connection_successful = False
|
||||
return False
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -208,39 +339,29 @@ class DatabaseClient(Application):
|
||||
"""Run the DatabaseClient."""
|
||||
super().run()
|
||||
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
def query(self, sql: str) -> bool:
|
||||
"""
|
||||
Send a query to the Database Service.
|
||||
|
||||
:param: sql: The SQL query.
|
||||
:param: is_reattempt: If true, the action has been reattempted.
|
||||
:type: sql: str
|
||||
|
||||
:return: True if the query was successful, otherwise False.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not self.native_connection:
|
||||
return False
|
||||
|
||||
# reset last query response
|
||||
self.last_query_response = None
|
||||
|
||||
connection_id: str
|
||||
|
||||
if not connection_id:
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
self.connect()
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
msg = "Cannot run sql query, could not establish connection with the server."
|
||||
self.parent.sys_log.warning(msg)
|
||||
return False
|
||||
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
return self.native_connection.query(sql)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def receive(self, session_id: str, payload: Any, **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
|
||||
@@ -250,12 +371,14 @@ class DatabaseClient(Application):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
if payload["response"] is True:
|
||||
# add connection
|
||||
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
|
||||
connection_id = payload["connection_id"]
|
||||
self._create_client_connection(
|
||||
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
self.last_query_response = payload
|
||||
query_id = payload.get("uuid")
|
||||
@@ -263,4 +386,8 @@ class DatabaseClient(Application):
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
self.sys_log.debug(f"Received {payload=}")
|
||||
elif payload["type"] == "disconnect":
|
||||
connection_id = payload["connection_id"]
|
||||
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server")
|
||||
self._disconnect(payload["connection_id"])
|
||||
return True
|
||||
|
||||
@@ -9,7 +9,7 @@ from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -53,6 +53,7 @@ class DataManipulationBot(Application):
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -148,6 +149,11 @@ class DataManipulationBot(Application):
|
||||
self.sys_log.debug(f"{self.name}: ")
|
||||
self.attack_stage = DataManipulationAttackStage.PORT_SCAN
|
||||
|
||||
def _establish_db_connection(self) -> bool:
|
||||
"""Establish a db connection to the Database Server."""
|
||||
self._db_connection = self._host_db_client.get_new_connection()
|
||||
return True if self._db_connection else False
|
||||
|
||||
def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
Execute the data manipulation attack on the target.
|
||||
@@ -167,12 +173,11 @@ class DataManipulationBot(Application):
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing data manipulation")
|
||||
# perform the attack
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Data manipulation successful")
|
||||
self.attack_stage = DataManipulationAttackStage.SUCCEEDED
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
class RansomwareAttackStage(IntEnum):
|
||||
@@ -73,6 +73,7 @@ class RansomwareScript(Application):
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -256,6 +257,11 @@ class RansomwareScript(Application):
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
def _establish_db_connection(self) -> bool:
|
||||
"""Establish a db connection to the Database Server."""
|
||||
self._db_connection = self._host_db_client.get_new_connection()
|
||||
return True if self._db_connection else False
|
||||
|
||||
def _perform_ransomware_encrypt(self):
|
||||
"""
|
||||
Execute the Ransomware Encrypt payload on the target.
|
||||
@@ -273,12 +279,11 @@ class RansomwareScript(Application):
|
||||
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
|
||||
if simulate_trial(self.ransomware_encrypt_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
self.attack_stage = RansomwareAttackStage.SUCCEEDED
|
||||
|
||||
@@ -113,6 +113,7 @@ class SoftwareManager:
|
||||
:param software_name: The software name.
|
||||
"""
|
||||
if software_name in self.software:
|
||||
self.software[software_name].uninstall()
|
||||
software = self.software.pop(software_name) # noqa
|
||||
if isinstance(software, Application):
|
||||
self.node.uninstall_application(software)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
@@ -145,8 +146,16 @@ class DatabaseService(Service):
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _generate_connection_id(self) -> str:
|
||||
"""Generate a unique connection ID."""
|
||||
return str(uuid4())
|
||||
|
||||
def _process_connect(
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
self,
|
||||
src_ip: IPv4Address,
|
||||
connection_request_id: str,
|
||||
password: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
"""Process an incoming connection request.
|
||||
|
||||
@@ -158,17 +167,17 @@ class DatabaseService(Service):
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
status_code = 500 # Default internal server error
|
||||
connection_id = None
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
|
||||
)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
connection_id = self._generate_connection_id()
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id):
|
||||
if not self.add_connection(connection_id=connection_id, session_id=session_id):
|
||||
status_code = 500
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
@@ -183,6 +192,7 @@ class DatabaseService(Service):
|
||||
"type": "connect_response",
|
||||
"response": status_code == 200,
|
||||
"connection_id": connection_id,
|
||||
"connection_request_id": connection_request_id,
|
||||
}
|
||||
|
||||
def _process_sql(
|
||||
@@ -299,19 +309,34 @@ class DatabaseService(Service):
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
src_ip = kwargs.get("frame").ip.src_ip_address
|
||||
result = self._process_connect(
|
||||
connection_id=payload.get("connection_id"), password=payload.get("password")
|
||||
src_ip=src_ip,
|
||||
password=payload.get("password"),
|
||||
connection_request_id=payload.get("connection_request_id"),
|
||||
session_id=session_id,
|
||||
)
|
||||
elif payload["type"] == "disconnect":
|
||||
if payload["connection_id"] in self.connections:
|
||||
self.remove_connection(connection_id=payload["connection_id"])
|
||||
connection_id = payload["connection_id"]
|
||||
connected_ip_address = self.connections[connection_id]["ip_address"]
|
||||
frame = kwargs.get("frame")
|
||||
if connected_ip_address == frame.ip.src_ip_address:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}"
|
||||
)
|
||||
self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False)
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source "
|
||||
f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})"
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
if payload.get("connection_id") in self.connections:
|
||||
result = self._process_sql(
|
||||
|
||||
@@ -276,7 +276,7 @@ class FTPClient(FTPServiceABC):
|
||||
|
||||
# if QUIT succeeded, remove the session from active connection list
|
||||
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
self.terminate_connection(connection_id=session_id)
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class FTPServer(FTPServiceABC):
|
||||
return payload
|
||||
|
||||
if payload.ftp_command == FTPCommand.QUIT:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
self.terminate_connection(connection_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite.simulator.network.protocols.http import (
|
||||
)
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClientConnection
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -48,6 +48,7 @@ class WebServer(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._install_web_files()
|
||||
self.start()
|
||||
self.db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
def _install_web_files(self):
|
||||
"""
|
||||
@@ -108,9 +109,11 @@ class WebServer(Service):
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
|
||||
# get all users
|
||||
if db_client.query("SELECT"):
|
||||
if not self.db_connection:
|
||||
self._establish_db_connection()
|
||||
|
||||
if self.db_connection.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
@@ -123,6 +126,11 @@ class WebServer(Service):
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
def _establish_db_connection(self) -> None:
|
||||
"""Establish a connection to db."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
self.db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: HttpResponsePacket,
|
||||
|
||||
@@ -5,6 +5,8 @@ from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -298,7 +300,7 @@ class IOSoftware(Software):
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
def add_connection(self, connection_id: Union[str, int], session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
@@ -323,6 +325,7 @@ class IOSoftware(Software):
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"session_id": session_id,
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
@@ -334,19 +337,41 @@ class IOSoftware(Software):
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
def terminate_connection(self, connection_id: str, send_disconnect: bool = True) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
Terminates a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:param send_disconnect: If True, sends a disconnect payload to the ip address of the associated session.
|
||||
:type: string
|
||||
"""
|
||||
if self.connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
connection_dict = self._connections.pop(connection_id)
|
||||
if send_disconnect:
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
session_id=connection_dict["session_id"],
|
||||
)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} terminated")
|
||||
return True
|
||||
return False
|
||||
|
||||
def show_connections(self, markdown: bool = False):
|
||||
"""
|
||||
Display the connections in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(["IP Address", "Connection ID", "Creation Timestamp"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} {self.name} Connections"
|
||||
for connection_id, data in self.connections.items():
|
||||
table.add_row([data["ip_address"], connection_id, str(data["time"])])
|
||||
print(table.get_string(sortby="Creation Timestamp"))
|
||||
|
||||
def clear_connections(self):
|
||||
"""Clears all the connections from the software."""
|
||||
|
||||
Reference in New Issue
Block a user