From cd5ed48b007c0b4e8304dd75f861698b488337bb Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Dec 2023 17:07:57 +0000 Subject: [PATCH] #2059: implementing the service connections limit --- src/primaite/simulator/network/networks.py | 4 +- .../system/applications/database_client.py | 109 ++++++++++------ .../red_applications/data_manipulation_bot.py | 4 +- .../services/database/database_service.py | 61 ++++++--- .../system/services/ftp/ftp_server.py | 13 +- .../simulator/system/services/service.py | 72 ++++++++++- src/primaite/simulator/system/software.py | 2 +- .../system/test_database_on_node.py | 121 +++++++++++++----- .../test_data_manipulation_bot.py | 2 +- .../_applications/test_database_client.py | 19 +-- 10 files changed, 280 insertions(+), 127 deletions(-) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 61ec7baf..630846b3 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network: database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) - database_service._process_sql(ddl, None) # noqa + database_service._process_sql(ddl, None, None) # noqa for insert_statement in user_insert_statements: - database_service._process_sql(insert_statement, None) # noqa + database_service._process_sql(insert_statement, None, None) # noqa # Web Server web_server = Server( diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index f57246fc..9d7bfcaa 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -23,7 +23,7 @@ class DatabaseClient(Application): server_ip_address: Optional[IPv4Address] = None server_password: Optional[str] = None - connected: bool = False + connections: Dict[str, Dict] = {} _query_success_tracker: Dict[str, bool] = {} def __init__(self, **kwargs): @@ -66,18 +66,24 @@ class DatabaseClient(Application): self.server_password = server_password self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.") - def connect(self) -> bool: + def connect(self, connection_id: Optional[str] = None) -> bool: """Connect to a Database Service.""" if not self._can_perform_action(): return False - if not self.connected: - return self._connect(self.server_ip_address, self.server_password) - # already connected - return True + if not connection_id: + connection_id = str(uuid4()) + + return self._connect( + server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id + ) def _connect( - self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False + self, + server_ip_address: IPv4Address, + connection_id: Optional[str] = None, + password: Optional[str] = None, + is_reattempt: bool = False, ) -> bool: """ Connects the DatabaseClient to the DatabaseServer. @@ -92,33 +98,58 @@ class DatabaseClient(Application): :type: is_reattempt: Optional[bool] """ if is_reattempt: - if self.connected: - self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised") + if self.connections.get(connection_id): + self.sys_log.info( + f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised" + ) self.server_ip_address = server_ip_address - return self.connected + return True else: - self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined") + self.sys_log.info( + f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined" + ) return False - payload = {"type": "connect_request", "password": password} + payload = { + "type": "connect_request", + "password": password, + "connection_id": connection_id, + } software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=server_ip_address, dest_port=self.port ) - return self._connect(server_ip_address, password, True) + return self._connect( + server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True + ) - def disconnect(self): + def disconnect(self, connection_id: Optional[str] = None) -> bool: """Disconnect from the Database Service.""" - if self.connected and self.operating_state is ApplicationOperatingState.RUNNING: - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port - ) + if not self._can_perform_action(): + self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}") + return False - self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}") - self.server_ip_address = None - self.connected = False + # if there are no connections - nothing to disconnect + if not len(self.connections): + self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.") + return False - def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool: + # if no connection provided, disconnect the first connection + if not connection_id: + connection_id = list(self.connections.keys())[0] + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": connection_id}, + dest_ip_address=self.server_ip_address, + dest_port=self.port, + ) + self.connections.pop(connection_id) + + self.sys_log.info( + f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}" + ) + + def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool: """ Send a query to the connected database server. @@ -141,11 +172,11 @@ class DatabaseClient(Application): else: software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "sql", "sql": sql, "uuid": query_id}, + payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id}, dest_ip_address=self.server_ip_address, dest_port=self.port, ) - return self._query(sql=sql, query_id=query_id, is_reattempt=True) + return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True) def run(self) -> None: """Run the DatabaseClient.""" @@ -153,7 +184,7 @@ class DatabaseClient(Application): if self.operating_state == ApplicationOperatingState.RUNNING: self.connect() - def query(self, sql: str, is_reattempt: bool = False) -> bool: + def query(self, sql: str, connection_id: Optional[str] = None) -> bool: """ Send a query to the Database Service. @@ -164,20 +195,17 @@ class DatabaseClient(Application): if not self._can_perform_action(): return False - if self.connected: - query_id = str(uuid4()) + if connection_id is None: + connection_id = str(uuid4()) + + if not self.connections.get(connection_id): + if not self.connect(connection_id=connection_id): + return False # Initialise the tracker of this ID to False - self._query_success_tracker[query_id] = False - return self._query(sql=sql, query_id=query_id) - else: - if is_reattempt: - return False - - if not self.connect(): - return False - - self.query(sql=sql, is_reattempt=True) + uuid = str(uuid4()) + self._query_success_tracker[uuid] = False + return self._query(sql=sql, query_id=uuid, connection_id=connection_id) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ @@ -192,13 +220,12 @@ class DatabaseClient(Application): if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_response": - self.connected = payload["response"] == True + if payload["response"] is True: + self.connections[payload.get("connection_id")] = payload elif payload["type"] == "sql": query_id = payload.get("uuid") status_code = payload.get("status_code") self._query_success_tracker[query_id] = status_code == 200 if self._query_success_tracker[query_id]: _LOGGER.debug(f"Received payload {payload}") - else: - self.connected = False return True diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 44a56cf1..87959e9b 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -149,9 +149,9 @@ class DataManipulationBot(DatabaseClient): if simulate_trial(p_of_success): self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack - if not self.connected: + if not len(self.connections): self.connect() - if self.connected: + if len(self.connections): self.query(self.payload) self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 61cf1560..70a4e6cc 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -1,4 +1,3 @@ -from datetime import datetime from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union @@ -22,7 +21,6 @@ class DatabaseService(Service): """ password: Optional[str] = None - connections: Dict[str, datetime] = {} backup_server: IPv4Address = None """IP address of the backup server.""" @@ -140,7 +138,7 @@ class DatabaseService(Service): self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id) def _process_connect( - self, session_id: str, password: Optional[str] = None + self, connection_id: str, password: Optional[str] = None ) -> Dict[str, Union[int, Dict[str, bool]]]: status_code = 500 # Default internal server error if self.operating_state == ServiceOperatingState.RUNNING: @@ -148,16 +146,27 @@ class DatabaseService(Service): if self.health_state_actual == SoftwareHealthState.GOOD: if self.password == password: status_code = 200 # ok - self.connections[session_id] = datetime.now() - self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised") + # try to create connection + if not self.add_connection(connection_id=connection_id): + status_code = 500 + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined") + else: + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") else: status_code = 401 # Unauthorised - self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined") + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined") else: status_code = 404 # service not found - return {"status_code": status_code, "type": "connect_response", "response": status_code == 200} + return { + "status_code": status_code, + "type": "connect_response", + "response": status_code == 200, + "connection_id": connection_id, + } - def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]: + def _process_sql( + self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None + ) -> Dict[str, Union[int, List[Any]]]: """ Executes the given SQL query and returns the result. @@ -169,15 +178,28 @@ class DatabaseService(Service): :return: Dictionary containing status code and data fetched. """ self.sys_log.info(f"{self.name}: Running {query}") + if query == "SELECT": if self.health_state_actual == SoftwareHealthState.GOOD: - return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id} + return { + "status_code": 200, + "type": "sql", + "data": True, + "uuid": query_id, + "connection_id": connection_id, + } else: return {"status_code": 404, "data": False} elif query == "DELETE": if self.health_state_actual == SoftwareHealthState.GOOD: self.health_state_actual = SoftwareHealthState.COMPROMISED - return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id} + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } else: return {"status_code": 404, "data": False} else: @@ -207,15 +229,24 @@ class DatabaseService(Service): return False result = {"status_code": 500, "data": []} + + # if server service is down, return error + if not self._can_perform_action(): + return False + if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_request": - result = self._process_connect(session_id=session_id, password=payload.get("password")) + result = self._process_connect( + connection_id=payload.get("connection_id"), password=payload.get("password") + ) elif payload["type"] == "disconnect": - if session_id in self.connections: - self.connections.pop(session_id) + if payload["connection_id"] in self.connections: + self.remove_connection(connection_id=payload["connection_id"]) elif payload["type"] == "sql": - if session_id in self.connections: - result = self._process_sql(query=payload["sql"], query_id=payload["uuid"]) + if payload.get("connection_id") in self.connections: + result = self._process_sql( + query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"] + ) else: result = {"status_code": 401, "type": "sql"} self.send(payload=result, session_id=session_id) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 0278b616..6e6c1a48 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -1,5 +1,4 @@ -from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Optional from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode @@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC): server_password: Optional[str] = None """Password needed to connect to FTP server. Default is None.""" - connections: Dict[str, IPv4Address] = {} - """Current active connections to the FTP server.""" - def __init__(self, **kwargs): kwargs["name"] = "FTPServer" kwargs["port"] = Port.FTP @@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC): self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}") - if session_id: - session_details = self._get_session_details(session_id) - if payload.ftp_command is not None: self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.") @@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC): # check that the port is valid if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535): # return successful connection - self.connections[session_id] = session_details.with_ip_address + self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK return payload @@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC): return payload if payload.ftp_command == FTPCommand.QUIT: - self.connections.pop(session_id) + self.remove_connection(connection_id=session_id) payload.status_code = FTPStatusCode.OK return payload diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e60b7700..52187e51 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,3 +1,5 @@ +import copy +from datetime import datetime from enum import Enum from typing import Any, Dict, Optional @@ -40,6 +42,15 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." + _connections: Dict[str, Dict] = {} + "Active connections to the Service." + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.health_state_visible = SoftwareHealthState.UNUSED + self.health_state_actual = SoftwareHealthState.UNUSED + def _can_perform_action(self) -> bool: """ Checks if the service can perform actions. @@ -74,12 +85,6 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.health_state_visible = SoftwareHealthState.UNUSED - self.health_state_actual = SoftwareHealthState.UNUSED - def set_original_state(self): """Sets the original state.""" super().set_original_state() @@ -98,6 +103,11 @@ class Service(IOSoftware): rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) return rm + @property + def connections(self) -> Dict[str, Dict]: + """Return the public version of connections.""" + return copy.copy(self._connections) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -113,6 +123,56 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible.value return state + def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool: + """ + Create a new connection to this service. + + Returns true if connection successfully created + + :param: connection_id: UUID of the connection to create + :type: string + """ + # if over or at capacity, set to overwhelmed + if len(self._connections) >= self.max_sessions: + self.health_state_actual = SoftwareHealthState.OVERWHELMED + self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") + return False + else: + # if service was previously overwhelmed, set to good because there is enough space for connections + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.health_state_actual = SoftwareHealthState.GOOD + + # check that connection already doesn't exist + if not self._connections.get(connection_id): + session_details = None + if session_id: + session_details = self._get_session_details(session_id) + self._connections[connection_id] = { + "ip_address": session_details.with_ip_address if session_details else None, + "time": datetime.now(), + } + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + return True + # connection with given id already exists + self.sys_log.error( + f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." + ) + return False + + def remove_connection(self, connection_id: str) -> bool: + """ + Remove a connection from this service. + + Returns true if connection successfully removed + + :param: connection_id: UUID of the connection to create + :type: string + """ + if self._connections.get(connection_id): + self._connections.pop(connection_id) + self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.") + return True + def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 87802a7b..8746bdf3 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -198,7 +198,7 @@ class IOSoftware(Software): installing_count: int = 0 "The number of times the software has been installed. Default is 0." - max_sessions: int = 1 + max_sessions: int = 100 "The maximum number of sessions that the software can handle simultaneously. Default is 0." tcp: bool = True "Indicates if the software uses TCP protocol for communication. Default is True." diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 98c8c87b..daa125ca 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -1,6 +1,9 @@ from ipaddress import IPv4Address +from typing import Tuple -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +import pytest + +from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService @@ -8,57 +11,109 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState -def test_database_client_server_connection(uc2_network): - web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") +@pytest.fixture(scope="function") +def peer_to_peer() -> Tuple[Node, Node]: + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON) + node_a.connect_nic(nic_a) + node_a.software_manager.get_open_ports() - db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") + node_b.connect_nic(nic_b) + Link(endpoint_a=nic_a, endpoint_b=nic_b) + + assert node_a.ping("192.168.0.11") + + node_a.software_manager.install(DatabaseClient) + node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) + node_a.software_manager.software["DatabaseClient"].run() + + node_b.software_manager.install(DatabaseService) + database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + database_service.start() + return node_a, node_b + + +@pytest.fixture(scope="function") +def peer_to_peer_secure_db() -> Tuple[Node, Node]: + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON) + node_a.connect_nic(nic_a) + node_a.software_manager.get_open_ports() + + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") + node_b.connect_nic(nic_b) + + Link(endpoint_a=nic_a, endpoint_b=nic_b) + + assert node_a.ping("192.168.0.11") + + node_a.software_manager.install(DatabaseClient) + node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) + node_a.software_manager.software["DatabaseClient"].run() + + node_b.software_manager.install(DatabaseService) + database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + database_service.password = "12345" + database_service.start() + return node_a, node_b + + +def test_database_client_server_connection(peer_to_peer): + node_a, node_b = peer_to_peer + + db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] + + db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + + db_client.connect() + assert len(db_client.connections) == 1 assert len(db_service.connections) == 1 db_client.disconnect() + assert len(db_client.connections) == 0 assert len(db_service.connections) == 0 -def test_database_client_server_correct_password(uc2_network): - web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") +def test_database_client_server_correct_password(peer_to_peer_secure_db): + node_a, node_b = peer_to_peer_secure_db - db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] - db_client.disconnect() - - db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345") - db_service.password = "12345" - - assert db_client.connect() + db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345") + db_client.connect() + assert len(db_client.connections) == 1 assert len(db_service.connections) == 1 -def test_database_client_server_incorrect_password(uc2_network): - web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") +def test_database_client_server_incorrect_password(peer_to_peer_secure_db): + node_a, node_b = peer_to_peer_secure_db - db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] - db_client.disconnect() - db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321") - db_service.password = "12345" + db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] - assert not db_client.connect() + # should fail + db_client.connect() + assert len(db_client.connections) == 0 + assert len(db_service.connections) == 0 + + db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="wrongpass") + db_client.connect() + assert len(db_client.connections) == 0 assert len(db_service.connections) == 0 def test_database_client_query(uc2_network): """Tests DB query across the network returns HTTP status 200 and date.""" web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") - - assert db_client.connected + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client.connect() assert db_client.query("SELECT") @@ -66,13 +121,13 @@ def test_database_client_query(uc2_network): def test_create_database_backup(uc2_network): """Run the backup_database method and check if the FTP server has the relevant file.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # back up should be created assert db_service.backup_database() is True backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"] # backup file should exist in the backup server assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None @@ -81,7 +136,7 @@ def test_create_database_backup(uc2_network): def test_restore_backup(uc2_network): """Run the restore_backup method and check if the backup is properly restored.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # create a back up assert db_service.backup_database() is True @@ -107,7 +162,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network): web_server: Server = uc2_network.get_node_by_hostname("web_server") db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") - assert db_client.connected + assert len(db_client.connections) assert db_client.query("SELECT") is True diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index b0ff0467..2ca67119 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -70,4 +70,4 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot._perform_data_manipulation(p_of_success=1.0) assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) - assert dm_bot.connected + assert len(dm_bot.connections) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 59d44561..15d28d4b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -1,5 +1,6 @@ from ipaddress import IPv4Address from typing import Tuple, Union +from uuid import uuid4 import pytest @@ -65,15 +66,14 @@ def test_disconnect(database_client_on_computer): """Database client should set connected to False and remove the database server ip address.""" database_client, computer = database_client_on_computer - database_client.connected = True + database_client.connections[uuid4()] = {} assert database_client.operating_state is ApplicationOperatingState.RUNNING assert database_client.server_ip_address is not None database_client.disconnect() - assert database_client.connected is False - assert database_client.server_ip_address is None + assert len(database_client.connections) == 0 def test_query_when_client_is_closed(database_client_on_computer): @@ -86,19 +86,6 @@ def test_query_when_client_is_closed(database_client_on_computer): assert database_client.query(sql="test") is False -def test_query_failed_reattempt(database_client_on_computer): - """Database client query should return False if the reattempt fails.""" - database_client, computer = database_client_on_computer - - def return_false(): - return False - - database_client.connect = return_false - - database_client.connected = False - assert database_client.query(sql="test", is_reattempt=True) is False - - def test_query_fail_to_connect(database_client_on_computer): """Database client query should return False if the connect attempt fails.""" database_client, computer = database_client_on_computer