From 094e89fff15a073207a43a8422a7fef23669544c Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Dec 2023 14:54:29 +0000 Subject: [PATCH 1/5] #2059: Renamed Red service to red application and moved the datamanipulation bot to the red application folder --- src/primaite/game/agent/data_manipulation_bot.py | 2 +- src/primaite/game/game.py | 3 ++- src/primaite/simulator/network/networks.py | 2 +- .../red_services => applications/red_applications}/__init__.py | 0 .../red_applications}/data_manipulation_bot.py | 0 .../test_uc2_data_manipulation_scenario.py | 2 +- .../_red_applications}/__init__.py | 0 .../_red_applications}/test_data_manipulation_bot.py | 2 +- 8 files changed, 6 insertions(+), 5 deletions(-) rename src/primaite/simulator/system/{services/red_services => applications/red_applications}/__init__.py (100%) rename src/primaite/simulator/system/{services/red_services => applications/red_applications}/data_manipulation_bot.py (100%) rename tests/unit_tests/_primaite/_simulator/_system/{_services/_red_services => _applications/_red_applications}/__init__.py (100%) rename tests/unit_tests/_primaite/_simulator/_system/{_services/_red_services => _applications/_red_applications}/test_data_manipulation_bot.py (96%) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 8237ce06..791c362d 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -4,7 +4,7 @@ from typing import Dict, List, Tuple from gymnasium.core import ObsType from primaite.game.agent.interface import AbstractScriptedAgent -from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot class DataManipulationAgent(AbstractScriptedAgent): diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8c32f41d..b6b815f1 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -20,13 +20,13 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer -from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) @@ -314,6 +314,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( server_ip_address=IPv4Address(opt.get("server_ip")), + server_password=opt.get("server_password"), payload=opt.get("payload"), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 4cd9c8d3..61ec7baf 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -9,10 +9,10 @@ from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer -from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.web_server.web_server import WebServer diff --git a/src/primaite/simulator/system/services/red_services/__init__.py b/src/primaite/simulator/system/applications/red_applications/__init__.py similarity index 100% rename from src/primaite/simulator/system/services/red_services/__init__.py rename to src/primaite/simulator/system/applications/red_applications/__init__.py diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py similarity index 100% rename from src/primaite/simulator/system/services/red_services/data_manipulation_bot.py rename to src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 0dc2c031..5206561b 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -1,8 +1,8 @@ from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService -from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot def test_data_manipulation(uc2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/__init__.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/__init__.py similarity index 100% rename from tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/__init__.py rename to tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/__init__.py diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py similarity index 96% rename from tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py rename to tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 2c4826bf..b0ff0467 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -5,7 +5,7 @@ from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState -from primaite.simulator.system.services.red_services.data_manipulation_bot import ( +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, ) From cd5ed48b007c0b4e8304dd75f861698b488337bb Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Dec 2023 17:07:57 +0000 Subject: [PATCH 2/5] #2059: implementing the service connections limit --- src/primaite/simulator/network/networks.py | 4 +- .../system/applications/database_client.py | 109 ++++++++++------ .../red_applications/data_manipulation_bot.py | 4 +- .../services/database/database_service.py | 61 ++++++--- .../system/services/ftp/ftp_server.py | 13 +- .../simulator/system/services/service.py | 72 ++++++++++- src/primaite/simulator/system/software.py | 2 +- .../system/test_database_on_node.py | 121 +++++++++++++----- .../test_data_manipulation_bot.py | 2 +- .../_applications/test_database_client.py | 19 +-- 10 files changed, 280 insertions(+), 127 deletions(-) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 61ec7baf..630846b3 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network: database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) - database_service._process_sql(ddl, None) # noqa + database_service._process_sql(ddl, None, None) # noqa for insert_statement in user_insert_statements: - database_service._process_sql(insert_statement, None) # noqa + database_service._process_sql(insert_statement, None, None) # noqa # Web Server web_server = Server( diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index f57246fc..9d7bfcaa 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -23,7 +23,7 @@ class DatabaseClient(Application): server_ip_address: Optional[IPv4Address] = None server_password: Optional[str] = None - connected: bool = False + connections: Dict[str, Dict] = {} _query_success_tracker: Dict[str, bool] = {} def __init__(self, **kwargs): @@ -66,18 +66,24 @@ class DatabaseClient(Application): self.server_password = server_password self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.") - def connect(self) -> bool: + def connect(self, connection_id: Optional[str] = None) -> bool: """Connect to a Database Service.""" if not self._can_perform_action(): return False - if not self.connected: - return self._connect(self.server_ip_address, self.server_password) - # already connected - return True + if not connection_id: + connection_id = str(uuid4()) + + return self._connect( + server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id + ) def _connect( - self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False + self, + server_ip_address: IPv4Address, + connection_id: Optional[str] = None, + password: Optional[str] = None, + is_reattempt: bool = False, ) -> bool: """ Connects the DatabaseClient to the DatabaseServer. @@ -92,33 +98,58 @@ class DatabaseClient(Application): :type: is_reattempt: Optional[bool] """ if is_reattempt: - if self.connected: - self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised") + if self.connections.get(connection_id): + self.sys_log.info( + f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised" + ) self.server_ip_address = server_ip_address - return self.connected + return True else: - self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined") + self.sys_log.info( + f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined" + ) return False - payload = {"type": "connect_request", "password": password} + payload = { + "type": "connect_request", + "password": password, + "connection_id": connection_id, + } software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=server_ip_address, dest_port=self.port ) - return self._connect(server_ip_address, password, True) + return self._connect( + server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True + ) - def disconnect(self): + def disconnect(self, connection_id: Optional[str] = None) -> bool: """Disconnect from the Database Service.""" - if self.connected and self.operating_state is ApplicationOperatingState.RUNNING: - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port - ) + if not self._can_perform_action(): + self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}") + return False - self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}") - self.server_ip_address = None - self.connected = False + # if there are no connections - nothing to disconnect + if not len(self.connections): + self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.") + return False - def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool: + # if no connection provided, disconnect the first connection + if not connection_id: + connection_id = list(self.connections.keys())[0] + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": connection_id}, + dest_ip_address=self.server_ip_address, + dest_port=self.port, + ) + self.connections.pop(connection_id) + + self.sys_log.info( + f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}" + ) + + def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool: """ Send a query to the connected database server. @@ -141,11 +172,11 @@ class DatabaseClient(Application): else: software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "sql", "sql": sql, "uuid": query_id}, + payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id}, dest_ip_address=self.server_ip_address, dest_port=self.port, ) - return self._query(sql=sql, query_id=query_id, is_reattempt=True) + return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True) def run(self) -> None: """Run the DatabaseClient.""" @@ -153,7 +184,7 @@ class DatabaseClient(Application): if self.operating_state == ApplicationOperatingState.RUNNING: self.connect() - def query(self, sql: str, is_reattempt: bool = False) -> bool: + def query(self, sql: str, connection_id: Optional[str] = None) -> bool: """ Send a query to the Database Service. @@ -164,20 +195,17 @@ class DatabaseClient(Application): if not self._can_perform_action(): return False - if self.connected: - query_id = str(uuid4()) + if connection_id is None: + connection_id = str(uuid4()) + + if not self.connections.get(connection_id): + if not self.connect(connection_id=connection_id): + return False # Initialise the tracker of this ID to False - self._query_success_tracker[query_id] = False - return self._query(sql=sql, query_id=query_id) - else: - if is_reattempt: - return False - - if not self.connect(): - return False - - self.query(sql=sql, is_reattempt=True) + uuid = str(uuid4()) + self._query_success_tracker[uuid] = False + return self._query(sql=sql, query_id=uuid, connection_id=connection_id) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ @@ -192,13 +220,12 @@ class DatabaseClient(Application): if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_response": - self.connected = payload["response"] == True + if payload["response"] is True: + self.connections[payload.get("connection_id")] = payload elif payload["type"] == "sql": query_id = payload.get("uuid") status_code = payload.get("status_code") self._query_success_tracker[query_id] = status_code == 200 if self._query_success_tracker[query_id]: _LOGGER.debug(f"Received payload {payload}") - else: - self.connected = False return True diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 44a56cf1..87959e9b 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -149,9 +149,9 @@ class DataManipulationBot(DatabaseClient): if simulate_trial(p_of_success): self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack - if not self.connected: + if not len(self.connections): self.connect() - if self.connected: + if len(self.connections): self.query(self.payload) self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 61cf1560..70a4e6cc 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -1,4 +1,3 @@ -from datetime import datetime from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union @@ -22,7 +21,6 @@ class DatabaseService(Service): """ password: Optional[str] = None - connections: Dict[str, datetime] = {} backup_server: IPv4Address = None """IP address of the backup server.""" @@ -140,7 +138,7 @@ class DatabaseService(Service): self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id) def _process_connect( - self, session_id: str, password: Optional[str] = None + self, connection_id: str, password: Optional[str] = None ) -> Dict[str, Union[int, Dict[str, bool]]]: status_code = 500 # Default internal server error if self.operating_state == ServiceOperatingState.RUNNING: @@ -148,16 +146,27 @@ class DatabaseService(Service): if self.health_state_actual == SoftwareHealthState.GOOD: if self.password == password: status_code = 200 # ok - self.connections[session_id] = datetime.now() - self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised") + # try to create connection + if not self.add_connection(connection_id=connection_id): + status_code = 500 + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined") + else: + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") else: status_code = 401 # Unauthorised - self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined") + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined") else: status_code = 404 # service not found - return {"status_code": status_code, "type": "connect_response", "response": status_code == 200} + return { + "status_code": status_code, + "type": "connect_response", + "response": status_code == 200, + "connection_id": connection_id, + } - def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]: + def _process_sql( + self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None + ) -> Dict[str, Union[int, List[Any]]]: """ Executes the given SQL query and returns the result. @@ -169,15 +178,28 @@ class DatabaseService(Service): :return: Dictionary containing status code and data fetched. """ self.sys_log.info(f"{self.name}: Running {query}") + if query == "SELECT": if self.health_state_actual == SoftwareHealthState.GOOD: - return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id} + return { + "status_code": 200, + "type": "sql", + "data": True, + "uuid": query_id, + "connection_id": connection_id, + } else: return {"status_code": 404, "data": False} elif query == "DELETE": if self.health_state_actual == SoftwareHealthState.GOOD: self.health_state_actual = SoftwareHealthState.COMPROMISED - return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id} + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } else: return {"status_code": 404, "data": False} else: @@ -207,15 +229,24 @@ class DatabaseService(Service): return False result = {"status_code": 500, "data": []} + + # if server service is down, return error + if not self._can_perform_action(): + return False + if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_request": - result = self._process_connect(session_id=session_id, password=payload.get("password")) + result = self._process_connect( + connection_id=payload.get("connection_id"), password=payload.get("password") + ) elif payload["type"] == "disconnect": - if session_id in self.connections: - self.connections.pop(session_id) + if payload["connection_id"] in self.connections: + self.remove_connection(connection_id=payload["connection_id"]) elif payload["type"] == "sql": - if session_id in self.connections: - result = self._process_sql(query=payload["sql"], query_id=payload["uuid"]) + if payload.get("connection_id") in self.connections: + result = self._process_sql( + query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"] + ) else: result = {"status_code": 401, "type": "sql"} self.send(payload=result, session_id=session_id) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 0278b616..6e6c1a48 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -1,5 +1,4 @@ -from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Optional from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode @@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC): server_password: Optional[str] = None """Password needed to connect to FTP server. Default is None.""" - connections: Dict[str, IPv4Address] = {} - """Current active connections to the FTP server.""" - def __init__(self, **kwargs): kwargs["name"] = "FTPServer" kwargs["port"] = Port.FTP @@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC): self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}") - if session_id: - session_details = self._get_session_details(session_id) - if payload.ftp_command is not None: self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.") @@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC): # check that the port is valid if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535): # return successful connection - self.connections[session_id] = session_details.with_ip_address + self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK return payload @@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC): return payload if payload.ftp_command == FTPCommand.QUIT: - self.connections.pop(session_id) + self.remove_connection(connection_id=session_id) payload.status_code = FTPStatusCode.OK return payload diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e60b7700..52187e51 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,3 +1,5 @@ +import copy +from datetime import datetime from enum import Enum from typing import Any, Dict, Optional @@ -40,6 +42,15 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." + _connections: Dict[str, Dict] = {} + "Active connections to the Service." + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.health_state_visible = SoftwareHealthState.UNUSED + self.health_state_actual = SoftwareHealthState.UNUSED + def _can_perform_action(self) -> bool: """ Checks if the service can perform actions. @@ -74,12 +85,6 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.health_state_visible = SoftwareHealthState.UNUSED - self.health_state_actual = SoftwareHealthState.UNUSED - def set_original_state(self): """Sets the original state.""" super().set_original_state() @@ -98,6 +103,11 @@ class Service(IOSoftware): rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) return rm + @property + def connections(self) -> Dict[str, Dict]: + """Return the public version of connections.""" + return copy.copy(self._connections) + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -113,6 +123,56 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible.value return state + def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool: + """ + Create a new connection to this service. + + Returns true if connection successfully created + + :param: connection_id: UUID of the connection to create + :type: string + """ + # if over or at capacity, set to overwhelmed + if len(self._connections) >= self.max_sessions: + self.health_state_actual = SoftwareHealthState.OVERWHELMED + self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") + return False + else: + # if service was previously overwhelmed, set to good because there is enough space for connections + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.health_state_actual = SoftwareHealthState.GOOD + + # check that connection already doesn't exist + if not self._connections.get(connection_id): + session_details = None + if session_id: + session_details = self._get_session_details(session_id) + self._connections[connection_id] = { + "ip_address": session_details.with_ip_address if session_details else None, + "time": datetime.now(), + } + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + return True + # connection with given id already exists + self.sys_log.error( + f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." + ) + return False + + def remove_connection(self, connection_id: str) -> bool: + """ + Remove a connection from this service. + + Returns true if connection successfully removed + + :param: connection_id: UUID of the connection to create + :type: string + """ + if self._connections.get(connection_id): + self._connections.pop(connection_id) + self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.") + return True + def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 87802a7b..8746bdf3 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -198,7 +198,7 @@ class IOSoftware(Software): installing_count: int = 0 "The number of times the software has been installed. Default is 0." - max_sessions: int = 1 + max_sessions: int = 100 "The maximum number of sessions that the software can handle simultaneously. Default is 0." tcp: bool = True "Indicates if the software uses TCP protocol for communication. Default is True." diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 98c8c87b..daa125ca 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -1,6 +1,9 @@ from ipaddress import IPv4Address +from typing import Tuple -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +import pytest + +from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService @@ -8,57 +11,109 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState -def test_database_client_server_connection(uc2_network): - web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") +@pytest.fixture(scope="function") +def peer_to_peer() -> Tuple[Node, Node]: + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON) + node_a.connect_nic(nic_a) + node_a.software_manager.get_open_ports() - db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") + node_b.connect_nic(nic_b) + Link(endpoint_a=nic_a, endpoint_b=nic_b) + + assert node_a.ping("192.168.0.11") + + node_a.software_manager.install(DatabaseClient) + node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) + node_a.software_manager.software["DatabaseClient"].run() + + node_b.software_manager.install(DatabaseService) + database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + database_service.start() + return node_a, node_b + + +@pytest.fixture(scope="function") +def peer_to_peer_secure_db() -> Tuple[Node, Node]: + node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON) + nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON) + node_a.connect_nic(nic_a) + node_a.software_manager.get_open_ports() + + node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON) + nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0") + node_b.connect_nic(nic_b) + + Link(endpoint_a=nic_a, endpoint_b=nic_b) + + assert node_a.ping("192.168.0.11") + + node_a.software_manager.install(DatabaseClient) + node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) + node_a.software_manager.software["DatabaseClient"].run() + + node_b.software_manager.install(DatabaseService) + database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + database_service.password = "12345" + database_service.start() + return node_a, node_b + + +def test_database_client_server_connection(peer_to_peer): + node_a, node_b = peer_to_peer + + db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] + + db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + + db_client.connect() + assert len(db_client.connections) == 1 assert len(db_service.connections) == 1 db_client.disconnect() + assert len(db_client.connections) == 0 assert len(db_service.connections) == 0 -def test_database_client_server_correct_password(uc2_network): - web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") +def test_database_client_server_correct_password(peer_to_peer_secure_db): + node_a, node_b = peer_to_peer_secure_db - db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] - db_client.disconnect() - - db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345") - db_service.password = "12345" - - assert db_client.connect() + db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345") + db_client.connect() + assert len(db_client.connections) == 1 assert len(db_service.connections) == 1 -def test_database_client_server_incorrect_password(uc2_network): - web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") +def test_database_client_server_incorrect_password(peer_to_peer_secure_db): + node_a, node_b = peer_to_peer_secure_db - db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] - db_client.disconnect() - db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321") - db_service.password = "12345" + db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] - assert not db_client.connect() + # should fail + db_client.connect() + assert len(db_client.connections) == 0 + assert len(db_service.connections) == 0 + + db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="wrongpass") + db_client.connect() + assert len(db_client.connections) == 0 assert len(db_service.connections) == 0 def test_database_client_query(uc2_network): """Tests DB query across the network returns HTTP status 200 and date.""" web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") - - assert db_client.connected + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client.connect() assert db_client.query("SELECT") @@ -66,13 +121,13 @@ def test_database_client_query(uc2_network): def test_create_database_backup(uc2_network): """Run the backup_database method and check if the FTP server has the relevant file.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # back up should be created assert db_service.backup_database() is True backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"] # backup file should exist in the backup server assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None @@ -81,7 +136,7 @@ def test_create_database_backup(uc2_network): def test_restore_backup(uc2_network): """Run the restore_backup method and check if the backup is properly restored.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # create a back up assert db_service.backup_database() is True @@ -107,7 +162,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network): web_server: Server = uc2_network.get_node_by_hostname("web_server") db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") - assert db_client.connected + assert len(db_client.connections) assert db_client.query("SELECT") is True diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index b0ff0467..2ca67119 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -70,4 +70,4 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot._perform_data_manipulation(p_of_success=1.0) assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) - assert dm_bot.connected + assert len(dm_bot.connections) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 59d44561..15d28d4b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -1,5 +1,6 @@ from ipaddress import IPv4Address from typing import Tuple, Union +from uuid import uuid4 import pytest @@ -65,15 +66,14 @@ def test_disconnect(database_client_on_computer): """Database client should set connected to False and remove the database server ip address.""" database_client, computer = database_client_on_computer - database_client.connected = True + database_client.connections[uuid4()] = {} assert database_client.operating_state is ApplicationOperatingState.RUNNING assert database_client.server_ip_address is not None database_client.disconnect() - assert database_client.connected is False - assert database_client.server_ip_address is None + assert len(database_client.connections) == 0 def test_query_when_client_is_closed(database_client_on_computer): @@ -86,19 +86,6 @@ def test_query_when_client_is_closed(database_client_on_computer): assert database_client.query(sql="test") is False -def test_query_failed_reattempt(database_client_on_computer): - """Database client query should return False if the reattempt fails.""" - database_client, computer = database_client_on_computer - - def return_false(): - return False - - database_client.connect = return_false - - database_client.connected = False - assert database_client.query(sql="test", is_reattempt=True) is False - - def test_query_fail_to_connect(database_client_on_computer): """Database client query should return False if the connect attempt fails.""" database_client, computer = database_client_on_computer From 4f79d2ad36abd5e25aca33e09577dd3669aa098b Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 12 Dec 2023 17:01:03 +0000 Subject: [PATCH 3/5] #2059: moved connection handling from Service to IOSoftware + changes that now utilise connections from IOSoftware + dos bot attacking now works + tests --- .../system/applications/database_client.py | 10 +- .../red_applications/data_manipulation_bot.py | 5 +- .../applications/red_applications/dos_bot.py | 184 ++++++++++++++++++ .../services/database/database_service.py | 7 +- .../system/services/ftp/ftp_client.py | 24 +-- .../system/services/ftp/ftp_server.py | 2 +- .../simulator/system/services/service.py | 60 ------ src/primaite/simulator/system/software.py | 63 ++++++ .../test_dos_bot_and_server.py | 107 ++++++++++ .../_red_applications/test_dos_bot.py | 90 +++++++++ .../_applications/test_database_client.py | 13 +- .../_system/_services/test_services.py | 33 ++++ 12 files changed, 510 insertions(+), 88 deletions(-) create mode 100644 src/primaite/simulator/system/applications/red_applications/dos_bot.py create mode 100644 tests/integration_tests/system/red_applications/test_dos_bot_and_server.py create mode 100644 tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 9d7bfcaa..fbeefe6a 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -5,7 +5,7 @@ from uuid import uuid4 from primaite import getLogger from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application, ApplicationOperatingState +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager _LOGGER = getLogger(__name__) @@ -23,7 +23,6 @@ class DatabaseClient(Application): server_ip_address: Optional[IPv4Address] = None server_password: Optional[str] = None - connections: Dict[str, Dict] = {} _query_success_tracker: Dict[str, bool] = {} def __init__(self, **kwargs): @@ -143,7 +142,7 @@ class DatabaseClient(Application): dest_ip_address=self.server_ip_address, dest_port=self.port, ) - self.connections.pop(connection_id) + self.remove_connection(connection_id=connection_id) self.sys_log.info( f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}" @@ -181,8 +180,6 @@ class DatabaseClient(Application): def run(self) -> None: """Run the DatabaseClient.""" super().run() - if self.operating_state == ApplicationOperatingState.RUNNING: - self.connect() def query(self, sql: str, connection_id: Optional[str] = None) -> bool: """ @@ -221,7 +218,8 @@ class DatabaseClient(Application): if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_response": if payload["response"] is True: - self.connections[payload.get("connection_id")] = payload + # add connection + self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id) elif payload["type"] == "sql": query_id = payload.get("uuid") status_code = payload.get("status_code") diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 87959e9b..a1429e51 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -5,7 +5,6 @@ from typing import Optional from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -177,9 +176,9 @@ class DataManipulationBot(DatabaseClient): This is the core loop where the bot sequentially goes through the stages of the attack. """ - if self.operating_state != ApplicationOperatingState.RUNNING: + if not self._can_perform_action(): return - if self.server_ip_address and self.payload and self.operating_state: + if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") self._logon() self._perform_port_scan(p_of_success=self.port_scan_p_of_success) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py new file mode 100644 index 00000000..e6c643ee --- /dev/null +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -0,0 +1,184 @@ +from enum import IntEnum +from ipaddress import IPv4Address +from typing import Optional + +from primaite import getLogger +from primaite.game.science import simulate_trial +from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.database_client import DatabaseClient + +_LOGGER = getLogger(__name__) + + +class DoSAttackStage(IntEnum): + """Enum representing the different stages of a Denial of Service attack.""" + + NOT_STARTED = 0 + "Attack not yet started." + + PORT_SCAN = 1 + "Attack is in discovery stage - checking if provided ip and port are open." + + ATTACKING = 2 + "Denial of Service attack is in progress." + + COMPLETED = 3 + "Attack is completed." + + +class DoSBot(DatabaseClient, Application): + """A bot that simulates a Denial of Service attack.""" + + target_ip_address: Optional[IPv4Address] = None + """IP address of the target service.""" + + target_port: Optional[Port] = None + """Port of the target service.""" + + payload: Optional[str] = None + """Payload to deliver to the target service as part of the denial of service attack.""" + + repeat: bool = False + """If true, the Denial of Service bot will keep performing the attack.""" + + attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED + """Current stage of the DoS kill chain.""" + + port_scan_p_of_success: float = 0.1 + """Probability of port scanning being sucessful.""" + + dos_intensity: float = 0.25 + """How much of the max sessions will be used by the DoS when attacking.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.name = "DoSBot" + self.max_sessions = 1000 # override normal max sessions + + def set_original_state(self): + """Set the original state of the Denial of Service Bot.""" + _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") + super().set_original_state() + vals_to_include = { + "target_ip_address", + "target_port", + "payload", + "repeat", + "attack_stage", + "max_sessions", + "port_scan_p_of_success", + "dos_intensity", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") + super().reset_component_for_episode(episode) + + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + + rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) + + return rm + + def configure( + self, + target_ip_address: IPv4Address, + target_port: Optional[Port] = Port.POSTGRES_SERVER, + payload: Optional[str] = None, + repeat: bool = False, + max_sessions: int = 1000, + ): + """ + Configure the Denial of Service bot. + + :param: target_ip_address: The IP address of the Node containing the target service. + :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` + :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` + :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` + :param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000 + """ + self.target_ip_address = target_ip_address + self.target_port = target_port + self.payload = payload + self.repeat = repeat + self.max_sessions = max_sessions + self.sys_log.info( + f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, {repeat=}." + ) + + def run(self): + """Run the Denial of Service Bot.""" + super().run() + self._application_loop() + + def _application_loop(self): + """ + The main application loop for the Denial of Service bot. + + The loop goes through the stages of a DoS attack. + """ + if not self._can_perform_action(): + return + + # DoS bot cannot do anything without a target + if not self.target_ip_address or not self.target_port: + self.sys_log.error( + f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}" + ) + return + + self.clear_connections() + self._perform_port_scan(p_of_success=self.port_scan_p_of_success) + self._perform_dos() + + if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING: + self.attack_stage = DoSAttackStage.NOT_STARTED + else: + self.attack_stage = DoSAttackStage.COMPLETED + + def _perform_port_scan(self, p_of_success: Optional[float] = 0.1): + """ + Perform a simulated port scan to check for open SQL ports. + + Advances the attack stage to `PORT_SCAN` if successful. + + :param p_of_success: Probability of successful port scan, by default 0.1. + """ + if self.attack_stage == DoSAttackStage.NOT_STARTED: + # perform a port scan to identify that the SQL port is open on the server + if simulate_trial(p_of_success): + self.sys_log.info(f"{self.name}: Performing port scan") + # perform the port scan + port_is_open = True # Temporary; later we can implement NMAP port scan. + if port_is_open: + self.sys_log.info(f"{self.name}: ") + self.attack_stage = DoSAttackStage.PORT_SCAN + + def _perform_dos(self): + """ + Perform the Denial of Service attack. + + DoSBot does this by clogging up the available connections to a service. + """ + if not self.attack_stage == DoSAttackStage.PORT_SCAN: + return + self.attack_stage = DoSAttackStage.ATTACKING + self.server_ip_address = self.target_ip_address + self.port = self.target_port + + dos_sessions = int(float(self.max_sessions) * self.dos_intensity) + for i in range(dos_sessions): + self.connect() + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the bot, iterate through the application loop. + + :param timestep: The timestep value to update the bot's state. + """ + self._application_loop() diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 70a4e6cc..7d313068 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -45,7 +45,7 @@ class DatabaseService(Service): super().set_original_state() vals_to_include = { "password", - "connections", + "_connections", "backup_server", "latest_backup_directory", "latest_backup_file_name", @@ -55,7 +55,7 @@ class DatabaseService(Service): def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.connections.clear() + self.clear_connections() super().reset_component_for_episode(episode) def configure_backup(self, backup_server: IPv4Address): @@ -225,9 +225,6 @@ class DatabaseService(Service): :param session_id: The session identifier. :return: True if the Status Code is 200, otherwise False. """ - if not super().receive(payload=payload, session_id=session_id, **kwargs): - return False - result = {"status_code": 500, "data": []} # if server service is down, return error diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 52655fa4..7faa5d32 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC): RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ - connected: bool = False - """Keeps track of whether or not the FTP client is connected to an FTP server.""" - def __init__(self, **kwargs): kwargs["name"] = "FTPClient" kwargs["port"] = Port.FTP @@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC): software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port ) - if payload.status_code == FTPStatusCode.OK: - self.connected = False - return True - return False + return payload.status_code == FTPStatusCode.OK def send_file( self, @@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC): return False # check if FTP is currently connected to IP - self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) + self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) - if not self.connected: + if not len(self.connections): return False else: self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}") @@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC): :type: dest_port: Optional[Port] """ # check if FTP is currently connected to IP - self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) + self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) - if not self.connected: + if not len(self.connections): return False else: # send retrieve request @@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC): self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}") return False + # if PORT succeeded, add the connection as an active connection list + if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK: + self.add_connection(connection_id=session_id, session_id=session_id) + + # if QUIT succeeded, remove the session from active connection list + if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK: + self.remove_connection(connection_id=session_id) + self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}") self._process_ftp_command(payload=payload, session_id=session_id) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 6e6c1a48..585690b6 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -37,7 +37,7 @@ class FTPServer(FTPServiceABC): def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.connections.clear() + self.clear_connections() super().reset_component_for_episode(episode) def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 52187e51..3155a4bd 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,5 +1,3 @@ -import copy -from datetime import datetime from enum import Enum from typing import Any, Dict, Optional @@ -42,9 +40,6 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." - _connections: Dict[str, Dict] = {} - "Active connections to the Service." - def __init__(self, **kwargs): super().__init__(**kwargs) @@ -103,11 +98,6 @@ class Service(IOSoftware): rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) return rm - @property - def connections(self) -> Dict[str, Dict]: - """Return the public version of connections.""" - return copy.copy(self._connections) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -123,56 +113,6 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible.value return state - def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool: - """ - Create a new connection to this service. - - Returns true if connection successfully created - - :param: connection_id: UUID of the connection to create - :type: string - """ - # if over or at capacity, set to overwhelmed - if len(self._connections) >= self.max_sessions: - self.health_state_actual = SoftwareHealthState.OVERWHELMED - self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") - return False - else: - # if service was previously overwhelmed, set to good because there is enough space for connections - if self.health_state_actual == SoftwareHealthState.OVERWHELMED: - self.health_state_actual = SoftwareHealthState.GOOD - - # check that connection already doesn't exist - if not self._connections.get(connection_id): - session_details = None - if session_id: - session_details = self._get_session_details(session_id) - self._connections[connection_id] = { - "ip_address": session_details.with_ip_address if session_details else None, - "time": datetime.now(), - } - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") - return True - # connection with given id already exists - self.sys_log.error( - f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." - ) - return False - - def remove_connection(self, connection_id: str) -> bool: - """ - Remove a connection from this service. - - Returns true if connection successfully removed - - :param: connection_id: UUID of the connection to create - :type: string - """ - if self._connections.get(connection_id): - self._connections.pop(connection_id) - self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.") - return True - def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 8746bdf3..b393ffd8 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,4 +1,6 @@ +import copy from abc import abstractmethod +from datetime import datetime from enum import Enum from ipaddress import IPv4Address from typing import Any, Dict, Optional @@ -206,6 +208,8 @@ class IOSoftware(Software): "Indicates if the software uses UDP protocol for communication. Default is True." port: Port "The port to which the software is connected." + _connections: Dict[str, Dict] = {} + "Active connections." def set_original_state(self): """Sets the original state.""" @@ -250,6 +254,65 @@ class IOSoftware(Software): return False return True + @property + def connections(self) -> Dict[str, Dict]: + """Return the public version of connections.""" + return copy.copy(self._connections) + + def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool: + """ + Create a new connection to this service. + + Returns true if connection successfully created + + :param: connection_id: UUID of the connection to create + :type: string + """ + # if over or at capacity, set to overwhelmed + if len(self._connections) >= self.max_sessions: + self.health_state_actual = SoftwareHealthState.OVERWHELMED + self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") + return False + else: + # if service was previously overwhelmed, set to good because there is enough space for connections + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.health_state_actual = SoftwareHealthState.GOOD + + # check that connection already doesn't exist + if not self._connections.get(connection_id): + session_details = None + if session_id: + session_details = self._get_session_details(session_id) + self._connections[connection_id] = { + "ip_address": session_details.with_ip_address if session_details else None, + "time": datetime.now(), + } + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + return True + # connection with given id already exists + self.sys_log.error( + f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." + ) + return False + + def remove_connection(self, connection_id: str) -> bool: + """ + Remove a connection from this service. + + Returns true if connection successfully removed + + :param: connection_id: UUID of the connection to create + :type: string + """ + if self.connections.get(connection_id): + self._connections.pop(connection_id) + self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.") + return True + + def clear_connections(self): + """Clears all the connections from the software.""" + self._connections = {} + def send( self, payload: Any, diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py new file mode 100644 index 00000000..2828cc25 --- /dev/null +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -0,0 +1,107 @@ +from ipaddress import IPv4Address +from typing import Tuple + +import pytest + +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.software import SoftwareHealthState + + +@pytest.fixture(scope="function") +def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseService, Server]: + computer, server = client_server + + # Install DoSBot on computer + computer.software_manager.install(DoSBot) + + dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + dos_bot.configure( + target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address), + target_port=Port.POSTGRES_SERVER, + ) + + # Install FTP Server service on server + server.software_manager.install(DatabaseService) + db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service.start() + + return dos_bot, computer, db_server_service, server + + +def test_repeating_dos_attack(dos_bot_and_db_server): + dos_bot, computer, db_server_service, server = dos_bot_and_db_server + + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + dos_bot.port_scan_p_of_success = 1 + dos_bot.repeat = True + dos_bot.run() + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + + assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + db_server_service.clear_connections() + db_server_service.health_state_actual = SoftwareHealthState.GOOD + assert len(db_server_service.connections) == 0 + + computer.apply_timestep(timestep=1) + server.apply_timestep(timestep=1) + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + + assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + +def test_non_repeating_dos_attack(dos_bot_and_db_server): + dos_bot, computer, db_server_service, server = dos_bot_and_db_server + + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + dos_bot.port_scan_p_of_success = 1 + dos_bot.repeat = False + dos_bot.run() + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + + assert dos_bot.attack_stage is DoSAttackStage.COMPLETED + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + db_server_service.clear_connections() + db_server_service.health_state_actual = SoftwareHealthState.GOOD + assert len(db_server_service.connections) == 0 + + computer.apply_timestep(timestep=1) + server.apply_timestep(timestep=1) + + assert len(dos_bot.connections) == 0 + assert len(db_server_service.connections) == 0 + assert len(dos_bot.connections) == 0 + + assert dos_bot.attack_stage is DoSAttackStage.COMPLETED + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + +def test_dos_bot_database_service_connection(dos_bot_and_db_server): + dos_bot, computer, db_server_service, server = dos_bot_and_db_server + + dos_bot.operating_state = ApplicationOperatingState.RUNNING + dos_bot.attack_stage = DoSAttackStage.PORT_SCAN + dos_bot._perform_dos() + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py new file mode 100644 index 00000000..71489171 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -0,0 +1,90 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot + + +@pytest.fixture(scope="function") +def dos_bot() -> DoSBot: + computer = Computer( + hostname="compromised_pc", + ip_address="192.168.0.1", + subnet_mask="255.255.255.0", + operating_state=NodeOperatingState.ON, + ) + + computer.software_manager.install(DoSBot) + + dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) + dos_bot.set_original_state() + return dos_bot + + +def test_dos_bot_creation(dos_bot): + """Test that the DoS bot is installed on a node.""" + assert dos_bot is not None + + +def test_dos_bot_reset(dos_bot): + assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") + assert dos_bot.target_port is Port.POSTGRES_SERVER + assert dos_bot.payload is None + assert dos_bot.repeat is False + + dos_bot.configure( + target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True + ) + + # should reset the relevant items + dos_bot.reset_component_for_episode(episode=0) + assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") + assert dos_bot.target_port is Port.POSTGRES_SERVER + assert dos_bot.payload is None + assert dos_bot.repeat is False + + dos_bot.configure( + target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True + ) + dos_bot.set_original_state() + dos_bot.reset_component_for_episode(episode=1) + # should reset to the configured value + assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") + assert dos_bot.target_port is Port.HTTP + assert dos_bot.payload == "payload" + assert dos_bot.repeat is True + + +def test_dos_bot_cannot_run_when_node_offline(dos_bot): + dos_bot_node: Computer = dos_bot.parent + assert dos_bot_node.operating_state is NodeOperatingState.ON + + dos_bot_node.power_off() + + for i in range(dos_bot_node.shut_down_duration + 1): + dos_bot_node.apply_timestep(timestep=i) + + assert dos_bot_node.operating_state is NodeOperatingState.OFF + + dos_bot._application_loop() + + # assert not run + assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED + + +def test_dos_bot_not_configured(dos_bot): + dos_bot.target_ip_address = None + + dos_bot.operating_state = ApplicationOperatingState.RUNNING + dos_bot._application_loop() + + +def test_dos_bot_perform_port_scan(dos_bot): + dos_bot._perform_port_scan(p_of_success=1) + + assert dos_bot.attack_stage is DoSAttackStage.PORT_SCAN diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 15d28d4b..204b356f 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -63,10 +63,11 @@ def test_disconnect_when_client_is_closed(database_client_on_computer): def test_disconnect(database_client_on_computer): - """Database client should set connected to False and remove the database server ip address.""" + """Database client should remove the connection.""" database_client, computer = database_client_on_computer - database_client.connections[uuid4()] = {} + database_client._connections[str(uuid4())] = {"item": True} + assert len(database_client.connections) == 1 assert database_client.operating_state is ApplicationOperatingState.RUNNING assert database_client.server_ip_address is not None @@ -75,6 +76,14 @@ def test_disconnect(database_client_on_computer): assert len(database_client.connections) == 0 + uuid = str(uuid4()) + database_client._connections[uuid] = {"item": True} + assert len(database_client.connections) == 1 + + database_client.disconnect(connection_id=uuid) + + assert len(database_client.connections) == 0 + def test_query_when_client_is_closed(database_client_on_computer): """Database client should return False when it is not running.""" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py index b32463a2..016cf011 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState @@ -66,3 +68,34 @@ def test_enable_disable(service): service.enable() assert service.operating_state == ServiceOperatingState.STOPPED + + +def test_overwhelm_service(service): + service.max_sessions = 2 + service.start() + + uuid = str(uuid4()) + assert service.add_connection(connection_id=uuid) # should be true + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert not service.add_connection(connection_id=uuid) # fails because connection already exists + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert service.add_connection(connection_id=str(uuid4())) # succeed + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity + assert service.health_state_actual is SoftwareHealthState.OVERWHELMED + + +def test_create_and_remove_connections(service): + service.start() + uuid = str(uuid4()) + + assert service.add_connection(connection_id=uuid) # should be true + assert len(service.connections) == 1 + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert service.remove_connection(connection_id=uuid) # should be true + assert len(service.connections) == 0 + assert service.health_state_actual is SoftwareHealthState.GOOD From f0be77c79b6f2118489b972fcd443d934b650129 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 12 Dec 2023 17:20:31 +0000 Subject: [PATCH 4/5] #2059: configure missing configurable items --- .../system/applications/red_applications/dos_bot.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index e6c643ee..84e0abb2 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -49,7 +49,7 @@ class DoSBot(DatabaseClient, Application): port_scan_p_of_success: float = 0.1 """Probability of port scanning being sucessful.""" - dos_intensity: float = 0.25 + dos_intensity: float = 1 """How much of the max sessions will be used by the DoS when attacking.""" def __init__(self, **kwargs): @@ -91,6 +91,8 @@ class DoSBot(DatabaseClient, Application): target_port: Optional[Port] = Port.POSTGRES_SERVER, payload: Optional[str] = None, repeat: bool = False, + port_scan_p_of_success: float = 0.1, + dos_intensity: float = 1, max_sessions: int = 1000, ): """ @@ -100,15 +102,21 @@ class DoSBot(DatabaseClient, Application): :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` + :param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%) + :param: dos_intensity: The intensity of the DoS attack. + Multiplied with the application's max session - Default is 1.0 :param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000 """ self.target_ip_address = target_ip_address self.target_port = target_port self.payload = payload self.repeat = repeat + self.port_scan_p_of_success = port_scan_p_of_success + self.dos_intensity = dos_intensity self.max_sessions = max_sessions self.sys_log.info( - f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, {repeat=}." + f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, " + f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}." ) def run(self): From 592e1a3610c2849e8873a9e372a6774ef9b95df7 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 13 Dec 2023 11:56:25 +0000 Subject: [PATCH 5/5] #2059: apply suggestions from PR + adding another test that checks for dos affecting green agent --- .../applications/red_applications/dos_bot.py | 4 +- .../services/database/database_service.py | 4 + .../simulator/system/services/service.py | 2 +- .../test_dos_bot_and_server.py | 75 ++++++++++++++++++- 4 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 84e0abb2..dfc48dd3 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -49,7 +49,7 @@ class DoSBot(DatabaseClient, Application): port_scan_p_of_success: float = 0.1 """Probability of port scanning being sucessful.""" - dos_intensity: float = 1 + dos_intensity: float = 1.0 """How much of the max sessions will be used by the DoS when attacking.""" def __init__(self, **kwargs): @@ -92,7 +92,7 @@ class DoSBot(DatabaseClient, Application): payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, - dos_intensity: float = 1, + dos_intensity: float = 1.0, max_sessions: int = 1000, ): """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 7d313068..6f333091 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -143,6 +143,10 @@ class DatabaseService(Service): status_code = 500 # Default internal server error if self.operating_state == ServiceOperatingState.RUNNING: status_code = 503 # service unavailable + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.sys_log.error( + f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity." + ) if self.health_state_actual == SoftwareHealthState.GOOD: if self.password == password: status_code = 200 # ok diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 3155a4bd..d45ef3a6 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -58,7 +58,7 @@ class Service(IOSoftware): if not super()._can_perform_action(): return False - if self.operating_state is not self.operating_state.RUNNING: + if self.operating_state is not ServiceOperatingState.RUNNING: # service is not running _LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}") return False diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 2828cc25..85028d75 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -3,10 +3,13 @@ from typing import Tuple import pytest +from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState @@ -25,7 +28,7 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ target_port=Port.POSTGRES_SERVER, ) - # Install FTP Server service on server + # Install DB Server service on server server.software_manager.install(DatabaseService) db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") db_server_service.start() @@ -33,6 +36,43 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ return dos_bot, computer, db_server_service, server +@pytest.fixture(scope="function") +def dos_bot_db_server_green_client(example_network) -> Network: + network: Network = example_network + + router_1: Router = example_network.get_node_by_hostname("router_1") + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + ) + + client_1: Computer = network.get_node_by_hostname("client_1") + client_2: Computer = network.get_node_by_hostname("client_2") + server: Server = network.get_node_by_hostname("server_1") + + # install DoS bot on client 1 + client_1.software_manager.install(DoSBot) + + dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") + dos_bot.configure( + target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address), + target_port=Port.POSTGRES_SERVER, + ) + + # install db server service on server + server.software_manager.install(DatabaseService) + db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service.start() + + # Install DB client (green) on client 2 + client_2.software_manager.install(DatabaseClient) + + database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) + database_client.run() + + return network + + def test_repeating_dos_attack(dos_bot_and_db_server): dos_bot, computer, db_server_service, server = dos_bot_and_db_server @@ -105,3 +145,36 @@ def test_dos_bot_database_service_connection(dos_bot_and_db_server): assert len(dos_bot.connections) == db_server_service.max_sessions assert len(db_server_service.connections) == db_server_service.max_sessions assert len(dos_bot.connections) == db_server_service.max_sessions + + +def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client): + network: Network = dos_bot_db_server_green_client + + client_1: Computer = network.get_node_by_hostname("client_1") + dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") + + client_2: Computer = network.get_node_by_hostname("client_2") + green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + + server: Server = network.get_node_by_hostname("server_1") + db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + dos_bot.port_scan_p_of_success = 1 + dos_bot.repeat = False + dos_bot.run() + + # DoS bot fills up connection of db server service + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(green_db_client.connections) == 0 + + assert dos_bot.attack_stage is DoSAttackStage.COMPLETED + # db server service is overwhelmed + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + # green agent tries to connect but fails because service is overwhelmed + assert green_db_client.connect() is False + assert len(green_db_client.connections) == 0