#2059: implementing the service connections limit
This commit is contained in:
@@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network:
|
||||
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
|
||||
database_service.start()
|
||||
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
|
||||
database_service._process_sql(ddl, None) # noqa
|
||||
database_service._process_sql(ddl, None, None) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database_service._process_sql(insert_statement, None) # noqa
|
||||
database_service._process_sql(insert_statement, None, None) # noqa
|
||||
|
||||
# Web Server
|
||||
web_server = Server(
|
||||
|
||||
@@ -23,7 +23,7 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
connections: Dict[str, Dict] = {}
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -66,18 +66,24 @@ class DatabaseClient(Application):
|
||||
self.server_password = server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self) -> bool:
|
||||
def connect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not self.connected:
|
||||
return self._connect(self.server_ip_address, self.server_password)
|
||||
# already connected
|
||||
return True
|
||||
if not connection_id:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
|
||||
)
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
connection_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Connects the DatabaseClient to the DatabaseServer.
|
||||
@@ -92,33 +98,58 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self.connected:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
|
||||
if self.connections.get(connection_id):
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
)
|
||||
self.server_ip_address = server_ip_address
|
||||
return self.connected
|
||||
return True
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
|
||||
)
|
||||
return False
|
||||
payload = {"type": "connect_request", "password": password}
|
||||
payload = {
|
||||
"type": "connect_request",
|
||||
"password": password,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(server_ip_address, password, True)
|
||||
return self._connect(
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
|
||||
self.server_ip_address = None
|
||||
self.connected = False
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not len(self.connections):
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
return False
|
||||
|
||||
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
|
||||
# if no connection provided, disconnect the first connection
|
||||
if not connection_id:
|
||||
connection_id = list(self.connections.keys())[0]
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.connections.pop(connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
|
||||
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
|
||||
"""
|
||||
Send a query to the connected database server.
|
||||
|
||||
@@ -141,11 +172,11 @@ class DatabaseClient(Application):
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id},
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
|
||||
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the DatabaseClient."""
|
||||
@@ -153,7 +184,7 @@ class DatabaseClient(Application):
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.connect()
|
||||
|
||||
def query(self, sql: str, is_reattempt: bool = False) -> bool:
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Send a query to the Database Service.
|
||||
|
||||
@@ -164,20 +195,17 @@ class DatabaseClient(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.connected:
|
||||
query_id = str(uuid4())
|
||||
if connection_id is None:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
if not self.connections.get(connection_id):
|
||||
if not self.connect(connection_id=connection_id):
|
||||
return False
|
||||
|
||||
# Initialise the tracker of this ID to False
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
else:
|
||||
if is_reattempt:
|
||||
return False
|
||||
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
self.query(sql=sql, is_reattempt=True)
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
@@ -192,13 +220,12 @@ class DatabaseClient(Application):
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
if payload["response"] is True:
|
||||
self.connections[payload.get("connection_id")] = payload
|
||||
elif payload["type"] == "sql":
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
else:
|
||||
self.connected = False
|
||||
return True
|
||||
|
||||
@@ -149,9 +149,9 @@ class DataManipulationBot(DatabaseClient):
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing data manipulation")
|
||||
# perform the attack
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
self.connect()
|
||||
if self.connected:
|
||||
if len(self.connections):
|
||||
self.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
@@ -22,7 +21,6 @@ class DatabaseService(Service):
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
connections: Dict[str, datetime] = {}
|
||||
|
||||
backup_server: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
@@ -140,7 +138,7 @@ class DatabaseService(Service):
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
@@ -148,16 +146,27 @@ class DatabaseService(Service):
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
self.connections[session_id] = datetime.now()
|
||||
self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id):
|
||||
status_code = 500
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
return {
|
||||
"status_code": status_code,
|
||||
"type": "connect_response",
|
||||
"response": status_code == 200,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(
|
||||
self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None
|
||||
) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
@@ -169,15 +178,28 @@ class DatabaseService(Service):
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": True,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
else:
|
||||
@@ -207,15 +229,24 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
result = self._process_connect(session_id=session_id, password=payload.get("password"))
|
||||
result = self._process_connect(
|
||||
connection_id=payload.get("connection_id"), password=payload.get("password")
|
||||
)
|
||||
elif payload["type"] == "disconnect":
|
||||
if session_id in self.connections:
|
||||
self.connections.pop(session_id)
|
||||
if payload["connection_id"] in self.connections:
|
||||
self.remove_connection(connection_id=payload["connection_id"])
|
||||
elif payload["type"] == "sql":
|
||||
if session_id in self.connections:
|
||||
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
|
||||
if payload.get("connection_id") in self.connections:
|
||||
result = self._process_sql(
|
||||
query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"]
|
||||
)
|
||||
else:
|
||||
result = {"status_code": 401, "type": "sql"}
|
||||
self.send(payload=result, session_id=session_id)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
@@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC):
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
connections: Dict[str, IPv4Address] = {}
|
||||
"""Current active connections to the FTP server."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
kwargs["port"] = Port.FTP
|
||||
@@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC):
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
|
||||
if payload.ftp_command is not None:
|
||||
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
|
||||
|
||||
@@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC):
|
||||
# check that the port is valid
|
||||
if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535):
|
||||
# return successful connection
|
||||
self.connections[session_id] = session_details.with_ip_address
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
@@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC):
|
||||
return payload
|
||||
|
||||
if payload.ftp_command == FTPCommand.QUIT:
|
||||
self.connections.pop(session_id)
|
||||
self.remove_connection(connection_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -40,6 +42,15 @@ class Service(IOSoftware):
|
||||
restart_countdown: Optional[int] = None
|
||||
"If currently restarting, how many timesteps remain until the restart is finished."
|
||||
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections to the Service."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -74,12 +85,6 @@ class Service(IOSoftware):
|
||||
"""
|
||||
return super().receive(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
@@ -98,6 +103,11 @@ class Service(IOSoftware):
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
return rm
|
||||
|
||||
@property
|
||||
def connections(self) -> Dict[str, Dict]:
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -113,6 +123,56 @@ class Service(IOSoftware):
|
||||
state["health_state_visible"] = self.health_state_visible.value
|
||||
return state
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
Returns true if connection successfully created
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
session_details = None
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
if self._connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
|
||||
@@ -198,7 +198,7 @@ class IOSoftware(Software):
|
||||
|
||||
installing_count: int = 0
|
||||
"The number of times the software has been installed. Default is 0."
|
||||
max_sessions: int = 1
|
||||
max_sessions: int = 100
|
||||
"The maximum number of sessions that the software can handle simultaneously. Default is 0."
|
||||
tcp: bool = True
|
||||
"Indicates if the software uses TCP protocol for communication. Default is True."
|
||||
|
||||
Reference in New Issue
Block a user